Bug 1676916 - Implicit bind group layouts in WebGPU r=jgilbert,webidl,smaug

This change updates and enables Naga to get the
SPIRV shaders parsed, validated, and reflected back into
implicit bind group layouts.
WebGPU examples heavily rely on the implicit layouts now,
and the PR also updates the WebIDL to make that possible.
With the change, we are able to run most of the examples again!

Differential Revision: https://phabricator.services.mozilla.com/D96850
This commit is contained in:
Dzmitry Malyshau 2020-11-13 14:15:49 +00:00
parent c4b40283ab
commit 8f74799ba5
67 changed files with 3182 additions and 1186 deletions

View file

@ -50,7 +50,7 @@ rev = "0917fe780032a6bbb23d71be545f9c1834128d75"
[source."https://github.com/gfx-rs/naga"]
git = "https://github.com/gfx-rs/naga"
replace-with = "vendored-sources"
rev = "aa35110471ee7915e1f4e1de61ea41f2f32f92c4"
rev = "4d4e1cd4cbfad2b81264a7239a336b6ec1346611"
[source."https://github.com/djg/cubeb-pulse-rs"]
git = "https://github.com/djg/cubeb-pulse-rs"

2
Cargo.lock generated
View file

@ -3329,7 +3329,7 @@ checksum = "a2983372caf4480544083767bf2d27defafe32af49ab4df3a0b7fc90793a3664"
[[package]]
name = "naga"
version = "0.2.0"
source = "git+https://github.com/gfx-rs/naga?rev=aa35110471ee7915e1f4e1de61ea41f2f32f92c4#aa35110471ee7915e1f4e1de61ea41f2f32f92c4"
source = "git+https://github.com/gfx-rs/naga?rev=4d4e1cd4cbfad2b81264a7239a336b6ec1346611#4d4e1cd4cbfad2b81264a7239a336b6ec1346611"
dependencies = [
"bitflags",
"fxhash",

View file

@ -13,8 +13,11 @@ namespace webgpu {
GPU_IMPL_CYCLE_COLLECTION(ComputePipeline, mParent)
GPU_IMPL_JS_WRAP(ComputePipeline)
ComputePipeline::ComputePipeline(Device* const aParent, RawId aId)
: ChildOf(aParent), mId(aId) {}
ComputePipeline::ComputePipeline(Device* const aParent, RawId aId,
nsTArray<RawId>&& aImplicitBindGroupLayoutIds)
: ChildOf(aParent),
mImplicitBindGroupLayoutIds(std::move(aImplicitBindGroupLayoutIds)),
mId(aId) {}
ComputePipeline::~ComputePipeline() { Cleanup(); }
@ -28,5 +31,12 @@ void ComputePipeline::Cleanup() {
}
}
already_AddRefed<BindGroupLayout> ComputePipeline::GetBindGroupLayout(
uint32_t index) const {
RefPtr<BindGroupLayout> object =
new BindGroupLayout(mParent, mImplicitBindGroupLayoutIds[index]);
return object.forget();
}
} // namespace webgpu
} // namespace mozilla

View file

@ -12,17 +12,22 @@
namespace mozilla {
namespace webgpu {
class BindGroupLayout;
class Device;
class ComputePipeline final : public ObjectBase, public ChildOf<Device> {
const nsTArray<RawId> mImplicitBindGroupLayoutIds;
public:
GPU_DECL_CYCLE_COLLECTION(ComputePipeline)
GPU_DECL_JS_WRAP(ComputePipeline)
ComputePipeline(Device* const aParent, RawId aId);
const RawId mId;
ComputePipeline(Device* const aParent, RawId aId,
nsTArray<RawId>&& aImplicitBindGroupLayoutIds);
already_AddRefed<BindGroupLayout> GetBindGroupLayout(uint32_t index) const;
private:
~ComputePipeline();
void Cleanup();

View file

@ -194,15 +194,21 @@ already_AddRefed<ShaderModule> Device::CreateShaderModule(
already_AddRefed<ComputePipeline> Device::CreateComputePipeline(
const dom::GPUComputePipelineDescriptor& aDesc) {
RawId id = mBridge->DeviceCreateComputePipeline(mId, aDesc);
RefPtr<ComputePipeline> object = new ComputePipeline(this, id);
nsTArray<RawId> implicitBindGroupLayoutIds;
RawId id = mBridge->DeviceCreateComputePipeline(mId, aDesc,
&implicitBindGroupLayoutIds);
RefPtr<ComputePipeline> object =
new ComputePipeline(this, id, std::move(implicitBindGroupLayoutIds));
return object.forget();
}
already_AddRefed<RenderPipeline> Device::CreateRenderPipeline(
const dom::GPURenderPipelineDescriptor& aDesc) {
RawId id = mBridge->DeviceCreateRenderPipeline(mId, aDesc);
RefPtr<RenderPipeline> object = new RenderPipeline(this, id);
nsTArray<RawId> implicitBindGroupLayoutIds;
RawId id = mBridge->DeviceCreateRenderPipeline(mId, aDesc,
&implicitBindGroupLayoutIds);
RefPtr<RenderPipeline> object =
new RenderPipeline(this, id, std::move(implicitBindGroupLayoutIds));
return object.forget();
}

View file

@ -13,8 +13,11 @@ namespace webgpu {
GPU_IMPL_CYCLE_COLLECTION(RenderPipeline, mParent)
GPU_IMPL_JS_WRAP(RenderPipeline)
RenderPipeline::RenderPipeline(Device* const aParent, RawId aId)
: ChildOf(aParent), mId(aId) {}
RenderPipeline::RenderPipeline(Device* const aParent, RawId aId,
nsTArray<RawId>&& aImplicitBindGroupLayoutIds)
: ChildOf(aParent),
mImplicitBindGroupLayoutIds(std::move(aImplicitBindGroupLayoutIds)),
mId(aId) {}
RenderPipeline::~RenderPipeline() { Cleanup(); }
@ -28,5 +31,12 @@ void RenderPipeline::Cleanup() {
}
}
already_AddRefed<BindGroupLayout> RenderPipeline::GetBindGroupLayout(
uint32_t index) const {
RefPtr<BindGroupLayout> object =
new BindGroupLayout(mParent, mImplicitBindGroupLayoutIds[index]);
return object.forget();
}
} // namespace webgpu
} // namespace mozilla

View file

@ -15,14 +15,18 @@ namespace webgpu {
class Device;
class RenderPipeline final : public ObjectBase, public ChildOf<Device> {
const nsTArray<RawId> mImplicitBindGroupLayoutIds;
public:
GPU_DECL_CYCLE_COLLECTION(RenderPipeline)
GPU_DECL_JS_WRAP(RenderPipeline)
RenderPipeline(Device* const aParent, RawId aId);
const RawId mId;
RenderPipeline(Device* const aParent, RawId aId,
nsTArray<RawId>&& aImplicitBindGroupLayoutIds);
already_AddRefed<BindGroupLayout> GetBindGroupLayout(uint32_t index) const;
private:
virtual ~RenderPipeline();
void Cleanup();

View file

@ -36,6 +36,7 @@ parent:
async DeviceAction(RawId selfId, ByteBuf buf);
async TextureAction(RawId selfId, ByteBuf buf);
async CommandEncoderAction(RawId selfId, ByteBuf buf);
async BumpImplicitBindGroupLayout(RawId pipelineId, bool isCompute, uint32_t index);
async InstanceRequestAdapter(GPURequestAdapterOptions options, RawId[] ids) returns (RawId adapterId);
async AdapterRequestDevice(RawId selfId, GPUDeviceDescriptor desc, RawId newId);
@ -69,6 +70,7 @@ parent:
async Shutdown();
child:
async DropAction(ByteBuf buf);
async FreeAdapter(RawId id);
async FreeDevice(RawId id);
async FreePipelineLayout(RawId id);

View file

@ -14,10 +14,6 @@ NS_IMPL_CYCLE_COLLECTION(WebGPUChild)
NS_IMPL_CYCLE_COLLECTION_ROOT_NATIVE(WebGPUChild, AddRef)
NS_IMPL_CYCLE_COLLECTION_UNROOT_NATIVE(WebGPUChild, Release)
ffi::WGPUByteBuf* ToFFI(ipc::ByteBuf* x) {
return reinterpret_cast<ffi::WGPUByteBuf*>(x);
}
static ffi::WGPUClient* initialize() {
ffi::WGPUInfrastructure infra = ffi::wgpu_client_new();
return infra.client;
@ -376,21 +372,30 @@ RawId WebGPUChild::DeviceCreateShaderModule(
}
RawId WebGPUChild::DeviceCreateComputePipeline(
RawId aSelfId, const dom::GPUComputePipelineDescriptor& aDesc) {
RawId aSelfId, const dom::GPUComputePipelineDescriptor& aDesc,
nsTArray<RawId>* const aImplicitBindGroupLayoutIds) {
ffi::WGPUComputePipelineDescriptor desc = {};
nsCString label, entryPoint;
if (aDesc.mLabel.WasPassed()) {
LossyCopyUTF16toASCII(aDesc.mLabel.Value(), label);
desc.label = label.get();
}
desc.layout = aDesc.mLayout->mId;
if (aDesc.mLayout.WasPassed()) {
desc.layout = aDesc.mLayout.Value().mId;
}
desc.compute_stage.module = aDesc.mComputeStage.mModule->mId;
LossyCopyUTF16toASCII(aDesc.mComputeStage.mEntryPoint, entryPoint);
desc.compute_stage.entry_point = entryPoint.get();
ByteBuf bb;
RawId id = ffi::wgpu_client_create_compute_pipeline(mClient, aSelfId, &desc,
ToFFI(&bb));
RawId implicit_bgl_ids[WGPUMAX_BIND_GROUPS] = {};
RawId id = ffi::wgpu_client_create_compute_pipeline(
mClient, aSelfId, &desc, ToFFI(&bb), implicit_bgl_ids);
for (const auto& cur : implicit_bgl_ids) {
if (!cur) break;
aImplicitBindGroupLayoutIds->AppendElement(cur);
}
if (!SendDeviceAction(aSelfId, std::move(bb))) {
MOZ_CRASH("IPC failure");
}
@ -457,7 +462,8 @@ static ffi::WGPUDepthStencilStateDescriptor ConvertDepthStencilDescriptor(
}
RawId WebGPUChild::DeviceCreateRenderPipeline(
RawId aSelfId, const dom::GPURenderPipelineDescriptor& aDesc) {
RawId aSelfId, const dom::GPURenderPipelineDescriptor& aDesc,
nsTArray<RawId>* const aImplicitBindGroupLayoutIds) {
ffi::WGPURenderPipelineDescriptor desc = {};
nsCString label, vsEntry, fsEntry;
ffi::WGPUProgrammableStageDescriptor vertexStage = {};
@ -467,7 +473,10 @@ RawId WebGPUChild::DeviceCreateRenderPipeline(
LossyCopyUTF16toASCII(aDesc.mLabel.Value(), label);
desc.label = label.get();
}
desc.layout = aDesc.mLayout->mId;
if (aDesc.mLayout.WasPassed()) {
desc.layout = aDesc.mLayout.Value().mId;
}
vertexStage.module = aDesc.mVertexStage.mModule->mId;
LossyCopyUTF16toASCII(aDesc.mVertexStage.mEntryPoint, vsEntry);
vertexStage.entry_point = vsEntry.get();
@ -537,14 +546,26 @@ RawId WebGPUChild::DeviceCreateRenderPipeline(
desc.alpha_to_coverage_enabled = aDesc.mAlphaToCoverageEnabled;
ByteBuf bb;
RawId id = ffi::wgpu_client_create_render_pipeline(mClient, aSelfId, &desc,
ToFFI(&bb));
RawId implicit_bgl_ids[WGPUMAX_BIND_GROUPS] = {};
RawId id = ffi::wgpu_client_create_render_pipeline(
mClient, aSelfId, &desc, ToFFI(&bb), implicit_bgl_ids);
for (const auto& cur : implicit_bgl_ids) {
if (!cur) break;
aImplicitBindGroupLayoutIds->AppendElement(cur);
}
if (!SendDeviceAction(aSelfId, std::move(bb))) {
MOZ_CRASH("IPC failure");
}
return id;
}
ipc::IPCResult WebGPUChild::RecvDropAction(const ipc::ByteBuf& aByteBuf) {
const auto* byteBuf = ToFFI(&aByteBuf);
ffi::wgpu_client_drop_action(mClient, byteBuf);
return IPC_OK();
}
ipc::IPCResult WebGPUChild::RecvFreeAdapter(RawId id) {
ffi::wgpu_client_kill_adapter_id(mClient, id);
return IPC_OK();

View file

@ -64,9 +64,11 @@ class WebGPUChild final : public PWebGPUChild, public SupportsWeakPtr {
RawId DeviceCreateShaderModule(RawId aSelfId,
const dom::GPUShaderModuleDescriptor& aDesc);
RawId DeviceCreateComputePipeline(
RawId aSelfId, const dom::GPUComputePipelineDescriptor& aDesc);
RawId aSelfId, const dom::GPUComputePipelineDescriptor& aDesc,
nsTArray<RawId>* const aImplicitBindGroupLayoutIds);
RawId DeviceCreateRenderPipeline(
RawId aSelfId, const dom::GPURenderPipelineDescriptor& aDesc);
RawId aSelfId, const dom::GPURenderPipelineDescriptor& aDesc,
nsTArray<RawId>* const aImplicitBindGroupLayoutIds);
void DeviceCreateSwapChain(RawId aSelfId, const RGBDescriptor& aRgbDesc,
size_t maxBufferCount,
@ -96,6 +98,7 @@ class WebGPUChild final : public PWebGPUChild, public SupportsWeakPtr {
bool mIPCOpen;
public:
ipc::IPCResult RecvDropAction(const ipc::ByteBuf& aByteBuf);
ipc::IPCResult RecvFreeAdapter(RawId id);
ipc::IPCResult RecvFreeDevice(RawId id);
ipc::IPCResult RecvFreePipelineLayout(RawId id);

View file

@ -173,6 +173,8 @@ ipc::IPCResult WebGPUParent::RecvInstanceRequestAdapter(
ipc::IPCResult WebGPUParent::RecvAdapterRequestDevice(
RawId aSelfId, const dom::GPUDeviceDescriptor& aDesc, RawId aNewId) {
ffi::WGPUDeviceDescriptor desc = {};
desc.shader_validation = true; // required for implicit pipeline layouts
if (aDesc.mLimits.WasPassed()) {
const auto& lim = aDesc.mLimits.Value();
desc.limits.max_bind_groups = lim.mMaxBindGroups;
@ -194,7 +196,7 @@ ipc::IPCResult WebGPUParent::RecvAdapterRequestDevice(
} else {
ffi::wgpu_server_fill_default_limits(&desc.limits);
}
// TODO: fill up the descriptor
ffi::wgpu_server_adapter_request_device(mContext, aSelfId, &desc, aNewId);
return IPC_OK();
}
@ -591,22 +593,40 @@ ipc::IPCResult WebGPUParent::RecvShutdown() {
ipc::IPCResult WebGPUParent::RecvDeviceAction(RawId aSelf,
const ipc::ByteBuf& aByteBuf) {
ffi::wgpu_server_device_action(
mContext, aSelf, reinterpret_cast<const ffi::WGPUByteBuf*>(&aByteBuf));
ipc::ByteBuf byteBuf;
ffi::wgpu_server_device_action(mContext, aSelf, ToFFI(&aByteBuf),
ToFFI(&byteBuf));
if (byteBuf.mData) {
if (!SendDropAction(std::move(byteBuf))) {
NS_WARNING("Unable to set a drop action!");
}
}
return IPC_OK();
}
ipc::IPCResult WebGPUParent::RecvTextureAction(RawId aSelf,
const ipc::ByteBuf& aByteBuf) {
ffi::wgpu_server_texture_action(
mContext, aSelf, reinterpret_cast<const ffi::WGPUByteBuf*>(&aByteBuf));
ffi::wgpu_server_texture_action(mContext, aSelf, ToFFI(&aByteBuf));
return IPC_OK();
}
ipc::IPCResult WebGPUParent::RecvCommandEncoderAction(
RawId aSelf, const ipc::ByteBuf& aByteBuf) {
ffi::wgpu_server_command_encoder_action(
mContext, aSelf, reinterpret_cast<const ffi::WGPUByteBuf*>(&aByteBuf));
ffi::wgpu_server_command_encoder_action(mContext, aSelf, ToFFI(&aByteBuf));
return IPC_OK();
}
ipc::IPCResult WebGPUParent::RecvBumpImplicitBindGroupLayout(RawId pipelineId,
bool isCompute,
uint32_t index) {
if (isCompute) {
ffi::wgpu_server_compute_pipeline_get_bind_group_layout(mContext,
pipelineId, index);
} else {
ffi::wgpu_server_render_pipeline_get_bind_group_layout(mContext, pipelineId,
index);
}
return IPC_OK();
}

View file

@ -70,6 +70,9 @@ class WebGPUParent final : public PWebGPUParent {
ipc::IPCResult RecvTextureAction(RawId aSelf, const ipc::ByteBuf& aByteBuf);
ipc::IPCResult RecvCommandEncoderAction(RawId aSelf,
const ipc::ByteBuf& aByteBuf);
ipc::IPCResult RecvBumpImplicitBindGroupLayout(RawId pipelineId,
bool isCompute,
uint32_t index);
ipc::IPCResult RecvShutdown();

View file

@ -31,8 +31,7 @@ DEFINE_IPC_SERIALIZER_WITHOUT_FIELDS(mozilla::dom::GPUCommandBufferDescriptor);
DEFINE_IPC_SERIALIZER_WITH_FIELDS(mozilla::dom::GPURequestAdapterOptions,
mPowerPreference);
DEFINE_IPC_SERIALIZER_WITH_FIELDS(mozilla::dom::GPUExtensions,
mAnisotropicFiltering);
DEFINE_IPC_SERIALIZER_WITHOUT_FIELDS(mozilla::dom::GPUExtensions);
DEFINE_IPC_SERIALIZER_WITH_FIELDS(mozilla::dom::GPULimits, mMaxBindGroups);
DEFINE_IPC_SERIALIZER_WITH_FIELDS(mozilla::dom::GPUDeviceDescriptor,
mExtensions, mLimits);

View file

@ -100,7 +100,6 @@ interface GPUAdapter {
GPUAdapter includes GPUObjectBase;
dictionary GPUExtensions {
boolean anisotropicFiltering = false;
};
dictionary GPULimits {
@ -412,7 +411,8 @@ GPUSampler includes GPUObjectBase;
enum GPUTextureComponentType {
"float",
"sint",
"uint"
"uint",
"depth-comparison"
};
// ****************************************************************************
@ -659,7 +659,11 @@ GPUShaderModule includes GPUObjectBase;
// Common stuff for ComputePipeline and RenderPipeline
dictionary GPUPipelineDescriptorBase : GPUObjectDescriptorBase {
required GPUPipelineLayout layout;
GPUPipelineLayout layout;
};
interface mixin GPUPipelineBase {
GPUBindGroupLayout getBindGroupLayout(unsigned long index);
};
dictionary GPUProgrammableStageDescriptor {
@ -677,6 +681,7 @@ dictionary GPUComputePipelineDescriptor : GPUPipelineDescriptorBase {
interface GPUComputePipeline {
};
GPUComputePipeline includes GPUObjectBase;
GPUComputePipeline includes GPUPipelineBase;
// GPURenderPipeline
enum GPUPrimitiveTopology {
@ -727,6 +732,7 @@ dictionary GPURenderPipelineDescriptor : GPUPipelineDescriptorBase {
interface GPURenderPipeline {
};
GPURenderPipeline includes GPUObjectBase;
GPURenderPipeline includes GPUPipelineBase;
// ****************************************************************************
// COMMAND RECORDING (Command buffer and all relevant structures)

34
gfx/wgpu/Cargo.lock generated
View file

@ -316,7 +316,7 @@ version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cb582b60359da160a9477ee80f15c8d784c477e69c217ef2cdd4169c24ea380f"
dependencies = [
"proc-macro2 1.0.18",
"proc-macro2 1.0.24",
"quote 1.0.7",
"syn",
]
@ -867,7 +867,7 @@ dependencies = [
[[package]]
name = "naga"
version = "0.2.0"
source = "git+https://github.com/gfx-rs/naga?rev=aa35110471ee7915e1f4e1de61ea41f2f32f92c4#aa35110471ee7915e1f4e1de61ea41f2f32f92c4"
source = "git+https://github.com/gfx-rs/naga?rev=4d4e1cd4cbfad2b81264a7239a336b6ec1346611#4d4e1cd4cbfad2b81264a7239a336b6ec1346611"
dependencies = [
"bitflags",
"fxhash",
@ -969,7 +969,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ffa5a33ddddfee04c0283a7653987d634e880347e96b5b2ed64de07efb59db9d"
dependencies = [
"proc-macro-crate",
"proc-macro2 1.0.18",
"proc-macro2 1.0.24",
"quote 1.0.7",
"syn",
]
@ -1117,9 +1117,9 @@ dependencies = [
[[package]]
name = "proc-macro2"
version = "1.0.18"
version = "1.0.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "beae6331a816b1f65d04c45b078fd8e6c93e8071771f41b8163255bbd8d7c8fa"
checksum = "1e0704ee1a7e00d7bb417d0770ea303c1bccbabf0ef1667dae92b5967f5f8a71"
dependencies = [
"unicode-xid 0.2.0",
]
@ -1145,7 +1145,7 @@ version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aa563d17ecb180e500da1cfd2b028310ac758de548efdd203e18f283af693f37"
dependencies = [
"proc-macro2 1.0.18",
"proc-macro2 1.0.24",
]
[[package]]
@ -1315,7 +1315,7 @@ version = "1.0.111"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f2c3ac8e6ca1e9c80b8be1023940162bf81ae3cffbb1809474152f2ce1eb250"
dependencies = [
"proc-macro2 1.0.18",
"proc-macro2 1.0.24",
"quote 1.0.7",
"syn",
]
@ -1409,11 +1409,11 @@ dependencies = [
[[package]]
name = "syn"
version = "1.0.31"
version = "1.0.48"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5304cfdf27365b7585c25d4af91b35016ed21ef88f17ced89c7093b43dba8b6"
checksum = "cc371affeffc477f42a221a1e4297aedcea33d47d19b61455588bd9d8f6b19ac"
dependencies = [
"proc-macro2 1.0.18",
"proc-macro2 1.0.24",
"quote 1.0.7",
"unicode-xid 0.2.0",
]
@ -1429,20 +1429,20 @@ dependencies = [
[[package]]
name = "thiserror"
version = "1.0.20"
version = "1.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7dfdd070ccd8ccb78f4ad66bf1982dc37f620ef696c6b5028fe2ed83dd3d0d08"
checksum = "0e9ae34b84616eedaaf1e9dd6026dbe00dcafa92aa0c8077cb69df1fcfe5e53e"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.20"
version = "1.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bd80fc12f73063ac132ac92aceea36734f04a1d93c1240c6944e23a3b8841793"
checksum = "9ba20f23e85b10754cd195504aebf6a27e2e6cbe28c17778a0c930724628dd56"
dependencies = [
"proc-macro2 1.0.18",
"proc-macro2 1.0.24",
"quote 1.0.7",
"syn",
]
@ -1600,7 +1600,7 @@ dependencies = [
"bumpalo",
"lazy_static",
"log",
"proc-macro2 1.0.18",
"proc-macro2 1.0.24",
"quote 1.0.7",
"syn",
"wasm-bindgen-shared",
@ -1622,7 +1622,7 @@ version = "0.2.63"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3156052d8ec77142051a533cdd686cba889537b213f948cd1d20869926e68e92"
dependencies = [
"proc-macro2 1.0.18",
"proc-macro2 1.0.24",
"quote 1.0.7",
"syn",
"wasm-bindgen-backend",

View file

@ -40,7 +40,7 @@ gfx-memory = "0.2"
[dependencies.naga]
version = "0.2"
git = "https://github.com/gfx-rs/naga"
rev = "aa35110471ee7915e1f4e1de61ea41f2f32f92c4"
rev = "4d4e1cd4cbfad2b81264a7239a336b6ec1346611"
features = ["spv-in", "spv-out", "wgsl-in"]
[dependencies.wgt]

View file

@ -2512,7 +2512,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
let module = if device.private_features.shader_validation {
// Parse the given shader code and store its representation.
let spv_iter = spv.into_iter().cloned();
naga::front::spv::Parser::new(spv_iter)
naga::front::spv::Parser::new(spv_iter, &Default::default())
.parse()
.map_err(|err| {
// TODO: eventually, when Naga gets support for all features,

View file

@ -132,14 +132,18 @@ fn get_aligned_type_size(
Ti::Pointer { .. } => 4,
Ti::Array {
base,
size: naga::ArraySize::Static(count),
size: naga::ArraySize::Constant(const_handle),
stride,
} => {
let base_size = match stride {
Some(stride) => stride.get() as wgt::BufferAddress,
None => get_aligned_type_size(module, base, false),
};
base_size * count as wgt::BufferAddress
let count = match module.constants[const_handle].inner {
naga::ConstantInner::Uint(count) => count,
ref other => panic!("Unexpected array size {:?}", other),
};
base_size * count
}
Ti::Array {
base,
@ -786,7 +790,7 @@ fn derive_binding_type(
dynamic,
min_binding_size: wgt::BufferSize::new(actual_size),
},
naga::StorageClass::StorageBuffer => BindingType::StorageBuffer {
naga::StorageClass::Storage => BindingType::StorageBuffer {
dynamic,
min_binding_size: wgt::BufferSize::new(actual_size),
readonly: !usage.contains(naga::GlobalUse::STORE),

View file

@ -4,7 +4,6 @@
// The `broken_intra_doc_links` is a new name, and will fail if built on the old compiler.
#![allow(unknown_lints)]
// The intra doc links to the wgpu crate in this crate actually succesfully link to the types in the wgpu crate, when built from the wgpu crate.
// However when building from both the wgpu crate or this crate cargo doc will claim all the links cannot be resolved
// despite the fact that it works fine when it needs to.

View file

@ -16,6 +16,7 @@ typedef uint8_t WGPUOption_NonZeroU8;
typedef uint64_t WGPUOption_AdapterId;
typedef uint64_t WGPUOption_BufferId;
typedef uint64_t WGPUOption_PipelineLayoutId;
typedef uint64_t WGPUOption_BindGroupLayoutId;
typedef uint64_t WGPUOption_SamplerId;
typedef uint64_t WGPUOption_SurfaceId;
typedef uint64_t WGPUOption_TextureViewId;
@ -30,7 +31,8 @@ style = "tag"
[export]
prefix = "WGPU"
exclude = [
"Option_AdapterId", "Option_BufferId", "Option_PipelineLayoutId", "Option_SamplerId", "Option_SurfaceId", "Option_TextureViewId",
"Option_AdapterId", "Option_BufferId", "Option_PipelineLayoutId", "Option_BindGroupLayoutId",
"Option_SamplerId", "Option_SurfaceId", "Option_TextureViewId",
"Option_BufferSize", "Option_NonZeroU32", "Option_NonZeroU8",
]

View file

@ -2,7 +2,10 @@
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */
use crate::{cow_label, ByteBuf, CommandEncoderAction, DeviceAction, RawString, TextureAction};
use crate::{
cow_label, ByteBuf, CommandEncoderAction, DeviceAction, DropAction, ImplicitLayout, RawString,
TextureAction,
};
use wgc::{hub::IdentityManager, id};
use wgt::Backend;
@ -13,20 +16,13 @@ use parking_lot::Mutex;
use std::{
borrow::Cow,
mem,
num::{NonZeroU32, NonZeroU8},
ptr, slice,
};
fn make_byte_buf<T: serde::Serialize>(data: &T) -> ByteBuf {
let vec = bincode::serialize(data).unwrap();
let bb = ByteBuf {
data: vec.as_ptr(),
len: vec.len(),
capacity: vec.capacity(),
};
mem::forget(vec);
bb
ByteBuf::from_vec(vec)
}
#[repr(C)]
@ -191,6 +187,19 @@ struct IdentityHub {
samplers: IdentityManager,
}
impl ImplicitLayout<'_> {
fn new(identities: &mut IdentityHub, backend: Backend) -> Self {
ImplicitLayout {
pipeline: identities.pipeline_layouts.alloc(backend),
bind_groups: Cow::Owned(
(0..wgc::MAX_BIND_GROUPS)
.map(|_| identities.bind_group_layouts.alloc(backend))
.collect(),
),
}
}
}
#[derive(Debug, Default)]
struct Identities {
surfaces: IdentityManager,
@ -219,6 +228,22 @@ pub struct Client {
identities: Mutex<Identities>,
}
#[no_mangle]
pub unsafe extern "C" fn wgpu_client_drop_action(client: &mut Client, byte_buf: &ByteBuf) {
let mut cursor = std::io::Cursor::new(byte_buf.as_slice());
let mut identities = client.identities.lock();
while let Ok(action) = bincode::deserialize_from(&mut cursor) {
match action {
DropAction::Buffer(id) => identities.select(id.backend()).buffers.free(id),
DropAction::Texture(id) => identities.select(id.backend()).textures.free(id),
DropAction::Sampler(id) => identities.select(id.backend()).samplers.free(id),
DropAction::BindGroupLayout(id) => {
identities.select(id.backend()).bind_group_layouts.free(id)
}
}
}
}
#[repr(C)]
#[derive(Debug)]
pub struct Infrastructure {
@ -613,8 +638,13 @@ pub unsafe extern "C" fn wgpu_client_create_bind_group_layout(
RawBindingType::Sampler => wgt::BindingType::Sampler { comparison: false },
RawBindingType::ComparisonSampler => wgt::BindingType::Sampler { comparison: true },
RawBindingType::SampledTexture => wgt::BindingType::SampledTexture {
dimension: *entry.view_dimension.unwrap(),
component_type: *entry.texture_component_type.unwrap(),
//TODO: the spec has a bug here
dimension: *entry
.view_dimension
.unwrap_or(&wgt::TextureViewDimension::D2),
component_type: *entry
.texture_component_type
.unwrap_or(&wgt::TextureComponentType::Float),
multisampled: entry.multisampled,
},
RawBindingType::ReadonlyStorageTexture => wgt::BindingType::StorageTexture {
@ -763,12 +793,15 @@ pub unsafe extern "C" fn wgpu_client_create_shader_module(
.alloc(backend);
assert!(!desc.spirv_words.is_null());
let data = Cow::Borrowed(slice::from_raw_parts(
desc.spirv_words,
desc.spirv_words_length,
));
let spv = Cow::Borrowed(if desc.spirv_words.is_null() {
&[][..]
} else {
slice::from_raw_parts(desc.spirv_words, desc.spirv_words_length)
});
let action = DeviceAction::CreateShaderModule(id, data);
let wgsl = cow_label(&desc.wgsl_chars).unwrap_or_default();
let action = DeviceAction::CreateShaderModule(id, spv, wgsl);
*bb = make_byte_buf(&action);
id
}
@ -789,14 +822,11 @@ pub unsafe extern "C" fn wgpu_client_create_compute_pipeline(
device_id: id::DeviceId,
desc: &ComputePipelineDescriptor,
bb: &mut ByteBuf,
implicit_bind_group_layout_ids: *mut Option<id::BindGroupLayoutId>,
) -> id::ComputePipelineId {
let backend = device_id.backend();
let id = client
.identities
.lock()
.select(backend)
.compute_pipelines
.alloc(backend);
let mut identities = client.identities.lock();
let id = identities.select(backend).compute_pipelines.alloc(backend);
let wgpu_desc = wgc::pipeline::ComputePipelineDescriptor {
label: cow_label(&desc.label),
@ -804,7 +834,18 @@ pub unsafe extern "C" fn wgpu_client_create_compute_pipeline(
compute_stage: desc.compute_stage.to_wgpu(),
};
let action = DeviceAction::CreateComputePipeline(id, wgpu_desc);
let implicit = match desc.layout {
Some(_) => None,
None => {
let implicit = ImplicitLayout::new(identities.select(backend), backend);
for (i, bgl_id) in implicit.bind_groups.iter().enumerate() {
*implicit_bind_group_layout_ids.add(i) = Some(*bgl_id);
}
Some(implicit)
}
};
let action = DeviceAction::CreateComputePipeline(id, wgpu_desc, implicit);
*bb = make_byte_buf(&action);
id
}
@ -825,14 +866,11 @@ pub unsafe extern "C" fn wgpu_client_create_render_pipeline(
device_id: id::DeviceId,
desc: &RenderPipelineDescriptor,
bb: &mut ByteBuf,
implicit_bind_group_layout_ids: *mut Option<id::BindGroupLayoutId>,
) -> id::RenderPipelineId {
let backend = device_id.backend();
let id = client
.identities
.lock()
.select(backend)
.render_pipelines
.alloc(backend);
let mut identities = client.identities.lock();
let id = identities.select(backend).render_pipelines.alloc(backend);
let wgpu_desc = wgc::pipeline::RenderPipelineDescriptor {
label: cow_label(&desc.label),
@ -875,7 +913,18 @@ pub unsafe extern "C" fn wgpu_client_create_render_pipeline(
alpha_to_coverage_enabled: desc.alpha_to_coverage_enabled,
};
let action = DeviceAction::CreateRenderPipeline(id, wgpu_desc);
let implicit = match desc.layout {
Some(_) => None,
None => {
let implicit = ImplicitLayout::new(identities.select(backend), backend);
for (i, bgl_id) in implicit.bind_groups.iter().enumerate() {
*implicit_bind_group_layout_ids.add(i) = Some(*bgl_id);
}
Some(implicit)
}
};
let action = DeviceAction::CreateRenderPipeline(id, wgpu_desc, implicit);
*bb = make_byte_buf(&action);
id
}

View file

@ -28,6 +28,7 @@ impl<I: id::TypedId + Clone + std::fmt::Debug> wgc::hub::IdentityHandler<I>
}
}
//TODO: remove this in favor of `DropAction` that could be sent over IPC.
#[repr(C)]
pub struct IdentityRecyclerFactory {
param: FactoryParam,

View file

@ -12,7 +12,7 @@ pub mod server;
pub use wgc::device::trace::Command as CommandEncoderAction;
use std::{borrow::Cow, slice};
use std::{borrow::Cow, mem, slice};
type RawString = *const std::os::raw::c_char;
@ -35,11 +35,35 @@ pub struct ByteBuf {
}
impl ByteBuf {
fn from_vec(vec: Vec<u8>) -> Self {
if vec.is_empty() {
ByteBuf {
data: std::ptr::null(),
len: 0,
capacity: 0,
}
} else {
let bb = ByteBuf {
data: vec.as_ptr(),
len: vec.len(),
capacity: vec.capacity(),
};
mem::forget(vec);
bb
}
}
unsafe fn as_slice(&self) -> &[u8] {
slice::from_raw_parts(self.data, self.len)
}
}
#[derive(serde::Serialize, serde::Deserialize)]
struct ImplicitLayout<'a> {
pipeline: id::PipelineLayoutId,
bind_groups: Cow<'a, [id::BindGroupLayoutId]>,
}
#[derive(serde::Serialize, serde::Deserialize)]
enum DeviceAction<'a> {
CreateBuffer(id::BufferId, wgc::resource::BufferDescriptor<'a>),
@ -54,14 +78,16 @@ enum DeviceAction<'a> {
wgc::binding_model::PipelineLayoutDescriptor<'a>,
),
CreateBindGroup(id::BindGroupId, wgc::binding_model::BindGroupDescriptor<'a>),
CreateShaderModule(id::ShaderModuleId, Cow<'a, [u32]>),
CreateShaderModule(id::ShaderModuleId, Cow<'a, [u32]>, Cow<'a, str>),
CreateComputePipeline(
id::ComputePipelineId,
wgc::pipeline::ComputePipelineDescriptor<'a>,
Option<ImplicitLayout<'a>>,
),
CreateRenderPipeline(
id::RenderPipelineId,
wgc::pipeline::RenderPipelineDescriptor<'a>,
Option<ImplicitLayout<'a>>,
),
CreateRenderBundle(
id::RenderBundleId,
@ -78,3 +104,11 @@ enum DeviceAction<'a> {
enum TextureAction<'a> {
CreateView(id::TextureViewId, wgc::resource::TextureViewDescriptor<'a>),
}
#[derive(serde::Serialize, serde::Deserialize)]
enum DropAction {
Buffer(id::BufferId),
Texture(id::TextureId),
Sampler(id::SamplerId),
BindGroupLayout(id::BindGroupLayoutId),
}

View file

@ -4,7 +4,7 @@
use crate::{
cow_label, identity::IdentityRecyclerFactory, ByteBuf, CommandEncoderAction, DeviceAction,
RawString, TextureAction,
DropAction, RawString, TextureAction,
};
use wgc::{gfx_select, id};
@ -164,7 +164,11 @@ pub extern "C" fn wgpu_server_buffer_drop(global: &Global, self_id: id::BufferId
}
trait GlobalExt {
fn device_action<B: wgc::hub::GfxBackend>(&self, self_id: id::DeviceId, action: DeviceAction);
fn device_action<B: wgc::hub::GfxBackend>(
&self,
self_id: id::DeviceId,
action: DeviceAction,
) -> Vec<u8>;
fn texture_action<B: wgc::hub::GfxBackend>(
&self,
self_id: id::TextureId,
@ -178,8 +182,12 @@ trait GlobalExt {
}
impl GlobalExt for Global {
fn device_action<B: wgc::hub::GfxBackend>(&self, self_id: id::DeviceId, action: DeviceAction) {
let implicit_ids = None; //TODO
fn device_action<B: wgc::hub::GfxBackend>(
&self,
self_id: id::DeviceId,
action: DeviceAction,
) -> Vec<u8> {
let mut drop_actions = Vec::new();
match action {
DeviceAction::CreateBuffer(id, desc) => {
self.device_create_buffer::<B>(self_id, &desc, id).unwrap();
@ -202,21 +210,54 @@ impl GlobalExt for Global {
self.device_create_bind_group::<B>(self_id, &desc, id)
.unwrap();
}
DeviceAction::CreateShaderModule(id, spirv) => {
self.device_create_shader_module::<B>(
self_id,
wgc::pipeline::ShaderModuleSource::SpirV(spirv),
id,
)
.unwrap();
}
DeviceAction::CreateComputePipeline(id, desc) => {
self.device_create_compute_pipeline::<B>(self_id, &desc, id, implicit_ids)
DeviceAction::CreateShaderModule(id, spirv, wgsl) => {
let source = if spirv.is_empty() {
wgc::pipeline::ShaderModuleSource::Wgsl(wgsl)
} else {
wgc::pipeline::ShaderModuleSource::SpirV(spirv)
};
self.device_create_shader_module::<B>(self_id, source, id)
.unwrap();
}
DeviceAction::CreateRenderPipeline(id, desc) => {
self.device_create_render_pipeline::<B>(self_id, &desc, id, implicit_ids)
DeviceAction::CreateComputePipeline(id, desc, implicit) => {
let implicit_ids = implicit
.as_ref()
.map(|imp| wgc::device::ImplicitPipelineIds {
root_id: imp.pipeline,
group_ids: &imp.bind_groups,
});
let (_, group_count) = self
.device_create_compute_pipeline::<B>(self_id, &desc, id, implicit_ids)
.unwrap();
if let Some(ref imp) = implicit {
for &bgl_id in imp.bind_groups[group_count as usize..].iter() {
bincode::serialize_into(
&mut drop_actions,
&DropAction::BindGroupLayout(bgl_id),
)
.unwrap();
}
}
}
DeviceAction::CreateRenderPipeline(id, desc, implicit) => {
let implicit_ids = implicit
.as_ref()
.map(|imp| wgc::device::ImplicitPipelineIds {
root_id: imp.pipeline,
group_ids: &imp.bind_groups,
});
let (_, group_count) = self
.device_create_render_pipeline::<B>(self_id, &desc, id, implicit_ids)
.unwrap();
if let Some(ref imp) = implicit {
for &bgl_id in imp.bind_groups[group_count as usize..].iter() {
bincode::serialize_into(
&mut drop_actions,
&DropAction::BindGroupLayout(bgl_id),
)
.unwrap();
}
}
}
DeviceAction::CreateRenderBundle(_id, desc, _base) => {
wgc::command::RenderBundleEncoder::new(&desc, self_id, None).unwrap();
@ -226,6 +267,7 @@ impl GlobalExt for Global {
.unwrap();
}
}
drop_actions
}
fn texture_action<B: wgc::hub::GfxBackend>(
@ -292,9 +334,11 @@ pub unsafe extern "C" fn wgpu_server_device_action(
global: &Global,
self_id: id::DeviceId,
byte_buf: &ByteBuf,
drop_byte_buf: &mut ByteBuf,
) {
let action = bincode::deserialize(byte_buf.as_slice()).unwrap();
gfx_select!(self_id => global.device_action(self_id, action));
let drop_actions = gfx_select!(self_id => global.device_action(self_id, action));
*drop_byte_buf = ByteBuf::from_vec(drop_actions);
}
#[no_mangle]
@ -465,3 +509,21 @@ pub extern "C" fn wgpu_server_texture_view_drop(global: &Global, self_id: id::Te
pub extern "C" fn wgpu_server_sampler_drop(global: &Global, self_id: id::SamplerId) {
gfx_select!(self_id => global.sampler_drop(self_id));
}
#[no_mangle]
pub extern "C" fn wgpu_server_compute_pipeline_get_bind_group_layout(
global: &Global,
self_id: id::ComputePipelineId,
index: u32,
) -> id::BindGroupLayoutId {
gfx_select!(self_id => global.compute_pipeline_get_bind_group_layout(self_id, index)).unwrap()
}
#[no_mangle]
pub extern "C" fn wgpu_server_render_pipeline_get_bind_group_layout(
global: &Global,
self_id: id::RenderPipelineId,
index: u32,
) -> id::BindGroupLayoutId {
gfx_select!(self_id => global.render_pipeline_get_bind_group_layout(self_id, index)).unwrap()
}

View file

@ -9,6 +9,9 @@
// Prelude of types necessary before including wgpu_ffi_generated.h
namespace mozilla {
namespace ipc {
class ByteBuf;
} // namespace ipc
namespace webgpu {
namespace ffi {
@ -23,6 +26,14 @@ extern "C" {
#undef WGPU_FUNC
} // namespace ffi
inline ffi::WGPUByteBuf* ToFFI(ipc::ByteBuf* x) {
return reinterpret_cast<ffi::WGPUByteBuf*>(x);
}
inline const ffi::WGPUByteBuf* ToFFI(const ipc::ByteBuf* x) {
return reinterpret_cast<const ffi::WGPUByteBuf*>(x);
}
} // namespace webgpu
} // namespace mozilla

File diff suppressed because one or more lines are too long

View file

@ -16,7 +16,7 @@ log = "0.4"
num-traits = "0.2"
spirv = { package = "spirv_headers", version = "1.4.2", optional = true }
pomelo = { version = "0.1.4", optional = true }
thiserror = "1.0"
thiserror = "1.0.21"
serde = { version = "1.0", features = ["derive"], optional = true }
petgraph = { version ="0.5", optional = true }

View file

@ -22,7 +22,7 @@ SPIR-V (binary) | :construction: | |
WGSL | | |
Metal | :construction: | |
HLSL | | |
GLSL | | |
GLSL | :construction: | |
AIR | | |
DXIR | | |
DXIL | | |

View file

@ -1,8 +1,16 @@
use serde::{Deserialize, Serialize};
use std::{env, fs, path::Path};
#[derive(Hash, PartialEq, Eq, Serialize, Deserialize)]
enum Stage {
Vertex,
Fragment,
Compute,
}
#[derive(Hash, PartialEq, Eq, Serialize, Deserialize)]
struct BindSource {
stage: Stage,
group: u32,
binding: u32,
}
@ -21,6 +29,8 @@ struct BindTarget {
#[derive(Default, Serialize, Deserialize)]
struct Parameters {
#[serde(default)]
spv_flow_dump_prefix: String,
metal_bindings: naga::FastHashMap<BindSource, BindTarget>,
}
@ -33,6 +43,13 @@ fn main() {
println!("Call with <input> <output>");
return;
}
let param_path = std::path::PathBuf::from(&args[1]).with_extension("param.ron");
let params = match fs::read_to_string(param_path) {
Ok(string) => ron::de::from_str(&string).unwrap(),
Err(_) => Parameters::default(),
};
let module = match Path::new(&args[1])
.extension()
.expect("Input has no extension?")
@ -41,8 +58,15 @@ fn main() {
{
#[cfg(feature = "spv-in")]
"spv" => {
let options = naga::front::spv::Options {
flow_graph_dump_prefix: if params.spv_flow_dump_prefix.is_empty() {
None
} else {
Some(params.spv_flow_dump_prefix.into())
},
};
let input = fs::read(&args[1]).unwrap();
naga::front::spv::parse_u8_slice(&input).unwrap()
naga::front::spv::parse_u8_slice(&input, &options).unwrap()
}
#[cfg(feature = "wgsl-in")]
"wgsl" => {
@ -52,17 +76,35 @@ fn main() {
#[cfg(feature = "glsl-in")]
"vert" => {
let input = fs::read_to_string(&args[1]).unwrap();
naga::front::glsl::parse_str(&input, "main", naga::ShaderStage::Vertex).unwrap()
naga::front::glsl::parse_str(
&input,
"main",
naga::ShaderStage::Vertex,
Default::default(),
)
.unwrap()
}
#[cfg(feature = "glsl-in")]
"frag" => {
let input = fs::read_to_string(&args[1]).unwrap();
naga::front::glsl::parse_str(&input, "main", naga::ShaderStage::Fragment).unwrap()
naga::front::glsl::parse_str(
&input,
"main",
naga::ShaderStage::Fragment,
Default::default(),
)
.unwrap()
}
#[cfg(feature = "glsl-in")]
"comp" => {
let input = fs::read_to_string(&args[1]).unwrap();
naga::front::glsl::parse_str(&input, "main", naga::ShaderStage::Compute).unwrap()
naga::front::glsl::parse_str(
&input,
"main",
naga::ShaderStage::Compute,
Default::default(),
)
.unwrap()
}
#[cfg(feature = "deserialize")]
"ron" => {
@ -83,12 +125,6 @@ fn main() {
return;
}
let param_path = std::path::PathBuf::from(&args[1]).with_extension("param.ron");
let params = match fs::read_to_string(param_path) {
Ok(string) => ron::de::from_str(&string).unwrap(),
Err(_) => Parameters::default(),
};
match Path::new(&args[2])
.extension()
.expect("Output has no extension?")
@ -102,6 +138,11 @@ fn main() {
for (key, value) in params.metal_bindings {
binding_map.insert(
msl::BindSource {
stage: match key.stage {
Stage::Vertex => naga::ShaderStage::Vertex,
Stage::Fragment => naga::ShaderStage::Fragment,
Stage::Compute => naga::ShaderStage::Compute,
},
group: key.group,
binding: key.binding,
},
@ -114,9 +155,11 @@ fn main() {
);
}
let options = msl::Options {
binding_map: &binding_map,
lang_version: (1, 0),
spirv_cross_compatibility: false,
binding_map,
};
let msl = msl::write_string(&module, options).unwrap();
let msl = msl::write_string(&module, &options).unwrap();
fs::write(&args[2], msl).unwrap();
}
#[cfg(feature = "spv-out")]
@ -198,7 +241,10 @@ fn main() {
}
other => {
let _ = params;
panic!("Unknown output extension: {}", other);
panic!(
"Unknown output extension: {}, forgot to enable a feature?",
other
);
}
}
}

View file

@ -1,11 +1,8 @@
use std::{cmp::Ordering, fmt, hash, marker::PhantomData, num::NonZeroU32};
/// An unique index in the arena array that a handle points to.
///
/// This type is independent of `spv::Word`. `spv::Word` is used in data
/// representation. It holds a SPIR-V and refers to that instruction. In
/// structured representation, we use Handle to refer to an SPIR-V instruction.
/// `Index` is an implementation detail to `Handle`.
/// The "non-zero" part ensures that an `Option<Handle<T>>` has
/// the same size and representation as `Handle<T>`.
type Index = NonZeroU32;
/// A strongly typed reference to a SPIR-V element.
@ -168,6 +165,10 @@ impl<T> Arena<T> {
self.fetch_if_or_append(value, T::eq)
}
pub fn try_get(&self, handle: Handle<T>) -> Option<&T> {
self.data.get(handle.index.get() as usize - 1)
}
/// Get a mutable reference to an element in the arena.
pub fn get_mut(&mut self, handle: Handle<T>) -> &mut T {
self.data.get_mut(handle.index.get() as usize - 1).unwrap()

View file

@ -112,6 +112,7 @@ bitflags::bitflags! {
const IMAGE_LOAD_STORE = 1 << 8;
const CONSERVATIVE_DEPTH = 1 << 9;
const TEXTURE_1D = 1 << 10;
const PUSH_CONSTANT = 1 << 11;
}
}
@ -364,7 +365,7 @@ pub fn write<'a>(
}
let block = match global.class {
StorageClass::StorageBuffer | StorageClass::Uniform => true,
StorageClass::Storage | StorageClass::Uniform => true,
_ => false,
};
@ -409,14 +410,28 @@ pub fn write<'a>(
&mut buf,
"{} {}({});",
func.return_type
.map(|ty| write_type(ty, &module.types, &structs, None, &mut manager))
.map(|ty| write_type(
ty,
&module.types,
&module.constants,
&structs,
None,
&mut manager
))
.transpose()?
.as_deref()
.unwrap_or("void"),
name,
func.parameter_types
func.arguments
.iter()
.map(|ty| write_type(*ty, &module.types, &structs, None, &mut manager))
.map(|arg| write_type(
arg.ty,
&module.types,
&module.constants,
&structs,
None,
&mut manager
))
.collect::<Result<Vec<_>, _>>()?
.join(","),
)?;
@ -557,14 +572,15 @@ pub fn write<'a>(
let name = if let Some(ref binding) = global.binding {
let prefix = match global.class {
StorageClass::Constant => "const",
StorageClass::Function => "fn",
StorageClass::Input => "in",
StorageClass::Output => "out",
StorageClass::Private => "priv",
StorageClass::StorageBuffer => "buffer",
StorageClass::Storage => "buffer",
StorageClass::Uniform => "uniform",
StorageClass::Handle => "handle",
StorageClass::WorkGroup => "wg",
StorageClass::PushConstant => "pc",
};
match binding {
@ -606,7 +622,7 @@ pub fn write<'a>(
}
let block = match global.class {
StorageClass::StorageBuffer | StorageClass::Uniform => {
StorageClass::Storage | StorageClass::Uniform => {
Some(format!("global_block_{}", handle.index()))
}
_ => None,
@ -616,7 +632,14 @@ pub fn write<'a>(
&mut buf,
"{}{} {};",
write_storage_class(global.class, &mut manager)?,
write_type(global.ty, &module.types, &structs, block, &mut manager)?,
write_type(
global.ty,
&module.types,
&module.constants,
&structs,
block,
&mut manager
)?,
name
)?;
@ -635,33 +658,53 @@ pub fn write<'a>(
global_vars: &module.global_variables,
local_vars: &func.local_variables,
functions: &module.functions,
parameter_types: &func.parameter_types,
arguments: &func.arguments,
},
)?;
let args: FastHashMap<_, _> = func
.parameter_types
.arguments
.iter()
.enumerate()
.map(|(pos, _)| (pos as u32, format!("arg_{}", pos)))
.map(|(pos, arg)| {
let name = arg
.name
.clone()
.filter(|ident| is_valid_ident(ident))
.unwrap_or_else(|| format!("arg_{}", pos + 1));
(pos as u32, name)
})
.collect();
writeln!(
&mut buf,
"{} {}({}) {{",
func.return_type
.map(|ty| write_type(ty, &module.types, &structs, None, &mut manager))
.map(|ty| write_type(
ty,
&module.types,
&module.constants,
&structs,
None,
&mut manager
))
.transpose()?
.as_deref()
.unwrap_or("void"),
name,
func.parameter_types
func.arguments
.iter()
.zip(args.values())
.map::<Result<_, Error>, _>(|(ty, name)| {
let ty = write_type(*ty, &module.types, &structs, None, &mut manager)?;
Ok(format!("{} {}", ty, name))
.enumerate()
.map::<Result<_, Error>, _>(|(pos, arg)| {
let ty = write_type(
arg.ty,
&module.types,
&module.constants,
&structs,
None,
&mut manager,
)?;
Ok(format!("{} {}", ty, args[&(pos as u32)]))
})
.collect::<Result<Vec<_>, _>>()?
.join(","),
@ -682,23 +725,6 @@ pub fn write<'a>(
})
.collect();
for (handle, name) in locals.iter() {
writeln!(
&mut buf,
"\t{} {};",
write_type(
func.local_variables[*handle].ty,
&module.types,
&structs,
None,
&mut manager
)?,
name
)?;
}
writeln!(&mut buf)?;
let mut builder = StatementBuilder {
functions: &functions,
globals: &globals_lookup,
@ -707,14 +733,40 @@ pub fn write<'a>(
args: &args,
expressions: &func.expressions,
typifier: &typifier,
manager: &mut manager,
};
for (handle, name) in locals.iter() {
let var = &func.local_variables[*handle];
write!(
&mut buf,
"\t{} {}",
write_type(
var.ty,
&module.types,
&module.constants,
&structs,
None,
&mut manager
)?,
name
)?;
if let Some(init) = var.init {
write!(
&mut buf,
" = {}",
write_constant(&module.constants[init], module, &mut builder, &mut manager)?
)?;
}
writeln!(&mut buf, ";")?;
}
writeln!(&mut buf)?;
for sta in func.body.iter() {
writeln!(
&mut buf,
"{}",
write_statement(sta, module, &mut builder, 1)?
write_statement(sta, module, &mut builder, &mut manager, 1)?
)?;
}
@ -746,19 +798,19 @@ struct StatementBuilder<'a> {
args: &'a FastHashMap<u32, String>,
expressions: &'a Arena<Expression>,
typifier: &'a Typifier,
pub manager: &'a mut FeaturesManager,
}
fn write_statement<'a, 'b>(
sta: &Statement,
module: &'a Module,
builder: &'b mut StatementBuilder<'a>,
manager: &mut FeaturesManager,
indent: usize,
) -> Result<String, Error> {
Ok(match sta {
Statement::Block(block) => block
.iter()
.map(|sta| write_statement(sta, module, builder, indent))
.map(|sta| write_statement(sta, module, builder, manager, indent))
.collect::<Result<Vec<_>, _>>()?
.join("\n"),
Statement::If {
@ -772,14 +824,14 @@ fn write_statement<'a, 'b>(
&mut out,
"{}if({}) {{",
"\t".repeat(indent),
write_expression(&builder.expressions[*condition], module, builder)?
write_expression(&builder.expressions[*condition], module, builder, manager)?
)?;
for sta in accept {
writeln!(
&mut out,
"{}",
write_statement(sta, module, builder, indent + 1)?
write_statement(sta, module, builder, manager, indent + 1)?
)?;
}
@ -789,7 +841,7 @@ fn write_statement<'a, 'b>(
writeln!(
&mut out,
"{}",
write_statement(sta, module, builder, indent + 1)?
write_statement(sta, module, builder, manager, indent + 1)?
)?;
}
}
@ -809,7 +861,7 @@ fn write_statement<'a, 'b>(
&mut out,
"{}switch({}) {{",
"\t".repeat(indent),
write_expression(&builder.expressions[*selector], module, builder)?
write_expression(&builder.expressions[*selector], module, builder, manager)?
)?;
for (label, (block, fallthrough)) in cases {
@ -819,7 +871,7 @@ fn write_statement<'a, 'b>(
writeln!(
&mut out,
"{}",
write_statement(sta, module, builder, indent + 2)?
write_statement(sta, module, builder, manager, indent + 2)?
)?;
}
@ -835,7 +887,7 @@ fn write_statement<'a, 'b>(
writeln!(
&mut out,
"{}",
write_statement(sta, module, builder, indent + 2)?
write_statement(sta, module, builder, manager, indent + 2)?
)?;
}
}
@ -853,7 +905,7 @@ fn write_statement<'a, 'b>(
writeln!(
&mut out,
"{}",
write_statement(sta, module, builder, indent + 1)?
write_statement(sta, module, builder, manager, indent + 1)?
)?;
}
@ -869,7 +921,7 @@ fn write_statement<'a, 'b>(
if let Some(expr) = value {
format!(
"return {};",
write_expression(&builder.expressions[*expr], module, builder)?
write_expression(&builder.expressions[*expr], module, builder, manager)?
)
} else {
String::from("return;")
@ -879,8 +931,8 @@ fn write_statement<'a, 'b>(
Statement::Store { pointer, value } => format!(
"{}{} = {};",
"\t".repeat(indent),
write_expression(&builder.expressions[*pointer], module, builder)?,
write_expression(&builder.expressions[*value], module, builder)?
write_expression(&builder.expressions[*pointer], module, builder, manager)?,
write_expression(&builder.expressions[*value], module, builder, manager)?
),
})
}
@ -889,18 +941,19 @@ fn write_expression<'a, 'b>(
expr: &Expression,
module: &'a Module,
builder: &'b mut StatementBuilder<'a>,
manager: &mut FeaturesManager,
) -> Result<Cow<'a, str>, Error> {
Ok(match *expr {
Expression::Access { base, index } => {
let base_expr = write_expression(&builder.expressions[base], module, builder)?;
let base_expr = write_expression(&builder.expressions[base], module, builder, manager)?;
Cow::Owned(format!(
"{}[{}]",
base_expr,
write_expression(&builder.expressions[index], module, builder)?
write_expression(&builder.expressions[index], module, builder, manager)?
))
}
Expression::AccessIndex { base, index } => {
let base_expr = write_expression(&builder.expressions[base], module, builder)?;
let base_expr = write_expression(&builder.expressions[base], module, builder, manager)?;
match *builder.typifier.get(base, &module.types) {
TypeInner::Vector { .. } => Cow::Owned(format!("{}[{}]", base_expr, index)),
@ -929,12 +982,13 @@ fn write_expression<'a, 'b>(
&module.constants[constant],
module,
builder,
manager,
)?),
Expression::Compose { ty, ref components } => {
let constructor = match module.types[ty].inner {
TypeInner::Vector { size, kind, width } => format!(
"{}vec{}",
map_scalar(kind, width, builder.manager)?.prefix,
map_scalar(kind, width, manager)?.prefix,
size as u8,
),
TypeInner::Matrix {
@ -943,19 +997,31 @@ fn write_expression<'a, 'b>(
width,
} => format!(
"{}mat{}x{}",
map_scalar(crate::ScalarKind::Float, width, builder.manager)?.prefix,
map_scalar(crate::ScalarKind::Float, width, manager)?.prefix,
columns as u8,
rows as u8,
),
TypeInner::Array { .. } => {
write_type(ty, &module.types, builder.structs, None, builder.manager)?
.into_owned()
}
TypeInner::Array { .. } => write_type(
ty,
&module.types,
&module.constants,
builder.structs,
None,
manager,
)?
.into_owned(),
TypeInner::Struct { .. } => builder.structs.get(&ty).unwrap().clone(),
_ => {
return Err(Error::Custom(format!(
"Cannot compose type {}",
write_type(ty, &module.types, builder.structs, None, builder.manager)?
write_type(
ty,
&module.types,
&module.constants,
builder.structs,
None,
manager
)?
)))
}
};
@ -968,19 +1034,20 @@ fn write_expression<'a, 'b>(
.map::<Result<_, Error>, _>(|arg| Ok(write_expression(
&builder.expressions[*arg],
module,
builder
builder,
manager,
)?))
.collect::<Result<Vec<_>, _>>()?
.join(","),
))
}
Expression::FunctionParameter(pos) => Cow::Borrowed(builder.args.get(&pos).unwrap()),
Expression::FunctionArgument(pos) => Cow::Borrowed(builder.args.get(&pos).unwrap()),
Expression::GlobalVariable(handle) => Cow::Borrowed(builder.globals.get(&handle).unwrap()),
Expression::LocalVariable(handle) => {
Cow::Borrowed(builder.locals_lookup.get(&handle).unwrap())
}
Expression::Load { pointer } => {
write_expression(&builder.expressions[pointer], module, builder)?
write_expression(&builder.expressions[pointer], module, builder, manager)?
}
Expression::ImageSample {
image,
@ -989,10 +1056,11 @@ fn write_expression<'a, 'b>(
level,
depth_ref,
} => {
let image_expr = write_expression(&builder.expressions[image], module, builder)?;
write_expression(&builder.expressions[sampler], module, builder)?;
let image_expr =
write_expression(&builder.expressions[image], module, builder, manager)?;
write_expression(&builder.expressions[sampler], module, builder, manager)?;
let coordinate_expr =
write_expression(&builder.expressions[coordinate], module, builder)?;
write_expression(&builder.expressions[coordinate], module, builder, manager)?;
let size = match *builder.typifier.get(coordinate, &module.types) {
TypeInner::Vector { size, .. } => size,
@ -1009,7 +1077,7 @@ fn write_expression<'a, 'b>(
"vec{}({},{})",
size as u8 + 1,
coordinate_expr,
write_expression(&builder.expressions[depth_ref], module, builder)?
write_expression(&builder.expressions[depth_ref], module, builder, manager)?
))
} else {
coordinate_expr
@ -1022,14 +1090,16 @@ fn write_expression<'a, 'b>(
format!("textureLod({},{},0)", image_expr, coordinate_expr)
}
crate::SampleLevel::Exact(expr) => {
let level_expr = write_expression(&builder.expressions[expr], module, builder)?;
let level_expr =
write_expression(&builder.expressions[expr], module, builder, manager)?;
format!(
"textureLod({}, {}, {})",
image_expr, coordinate_expr, level_expr
)
}
crate::SampleLevel::Bias(bias) => {
let bias_expr = write_expression(&builder.expressions[bias], module, builder)?;
let bias_expr =
write_expression(&builder.expressions[bias], module, builder, manager)?;
format!("texture({},{},{})", image_expr, coordinate_expr, bias_expr)
}
})
@ -1039,9 +1109,10 @@ fn write_expression<'a, 'b>(
coordinate,
index,
} => {
let image_expr = write_expression(&builder.expressions[image], module, builder)?;
let image_expr =
write_expression(&builder.expressions[image], module, builder, manager)?;
let coordinate_expr =
write_expression(&builder.expressions[coordinate], module, builder)?;
write_expression(&builder.expressions[coordinate], module, builder, manager)?;
let (dim, arrayed, class) = match *builder.typifier.get(image, &module.types) {
TypeInner::Image {
@ -1057,15 +1128,19 @@ fn write_expression<'a, 'b>(
//TODO: fix this
let sampler_constructor = format!(
"{}sampler{}{}{}({})",
map_scalar(kind, 4, builder.manager)?.prefix,
map_scalar(kind, 4, manager)?.prefix,
ImageDimension(dim),
if multi { "MS" } else { "" },
if arrayed { "Array" } else { "" },
image_expr,
);
let index_expr =
write_expression(&builder.expressions[index.unwrap()], module, builder)?;
let index_expr = write_expression(
&builder.expressions[index.unwrap()],
module,
builder,
manager,
)?;
format!(
"texelFetch({},{},{})",
sampler_constructor, coordinate_expr, index_expr
@ -1076,7 +1151,7 @@ fn write_expression<'a, 'b>(
})
}
Expression::Unary { op, expr } => {
let base_expr = write_expression(&builder.expressions[expr], module, builder)?;
let base_expr = write_expression(&builder.expressions[expr], module, builder, manager)?;
Cow::Owned(format!(
"({} {})",
@ -1106,8 +1181,9 @@ fn write_expression<'a, 'b>(
))
}
Expression::Binary { op, left, right } => {
let left_expr = write_expression(&builder.expressions[left], module, builder)?;
let right_expr = write_expression(&builder.expressions[right], module, builder)?;
let left_expr = write_expression(&builder.expressions[left], module, builder, manager)?;
let right_expr =
write_expression(&builder.expressions[right], module, builder, manager)?;
let op_str = match op {
BinaryOperator::Add => "+",
@ -1126,15 +1202,30 @@ fn write_expression<'a, 'b>(
BinaryOperator::InclusiveOr => "|",
BinaryOperator::LogicalAnd => "&&",
BinaryOperator::LogicalOr => "||",
BinaryOperator::ShiftLeftLogical => "<<",
BinaryOperator::ShiftRightLogical => todo!(),
BinaryOperator::ShiftRightArithmetic => ">>",
BinaryOperator::ShiftLeft => "<<",
BinaryOperator::ShiftRight => ">>",
};
Cow::Owned(format!("({} {} {})", left_expr, op_str, right_expr))
}
Expression::Select {
condition,
accept,
reject,
} => {
let cond_expr =
write_expression(&builder.expressions[condition], module, builder, manager)?;
let accept_expr =
write_expression(&builder.expressions[accept], module, builder, manager)?;
let reject_expr =
write_expression(&builder.expressions[reject], module, builder, manager)?;
Cow::Owned(format!(
"({} ? {} : {})",
cond_expr, accept_expr, reject_expr
))
}
Expression::Intrinsic { fun, argument } => {
let expr = write_expression(&builder.expressions[argument], module, builder)?;
let expr = write_expression(&builder.expressions[argument], module, builder, manager)?;
Cow::Owned(format!(
"{:?}({})",
@ -1150,17 +1241,20 @@ fn write_expression<'a, 'b>(
))
}
Expression::Transpose(matrix) => {
let matrix_expr = write_expression(&builder.expressions[matrix], module, builder)?;
let matrix_expr =
write_expression(&builder.expressions[matrix], module, builder, manager)?;
Cow::Owned(format!("transpose({})", matrix_expr))
}
Expression::DotProduct(left, right) => {
let left_expr = write_expression(&builder.expressions[left], module, builder)?;
let right_expr = write_expression(&builder.expressions[right], module, builder)?;
let left_expr = write_expression(&builder.expressions[left], module, builder, manager)?;
let right_expr =
write_expression(&builder.expressions[right], module, builder, manager)?;
Cow::Owned(format!("dot({},{})", left_expr, right_expr))
}
Expression::CrossProduct(left, right) => {
let left_expr = write_expression(&builder.expressions[left], module, builder)?;
let right_expr = write_expression(&builder.expressions[right], module, builder)?;
let left_expr = write_expression(&builder.expressions[left], module, builder, manager)?;
let right_expr =
write_expression(&builder.expressions[right], module, builder, manager)?;
Cow::Owned(format!("cross({},{})", left_expr, right_expr))
}
Expression::As {
@ -1168,7 +1262,8 @@ fn write_expression<'a, 'b>(
kind,
convert,
} => {
let value_expr = write_expression(&builder.expressions[expr], module, builder)?;
let value_expr =
write_expression(&builder.expressions[expr], module, builder, manager)?;
let (source_kind, ty_expr) = match *builder.typifier.get(expr, &module.types) {
TypeInner::Scalar {
@ -1176,7 +1271,7 @@ fn write_expression<'a, 'b>(
kind: source_kind,
} => (
source_kind,
Cow::Borrowed(map_scalar(kind, width, builder.manager)?.full),
Cow::Borrowed(map_scalar(kind, width, manager)?.full),
),
TypeInner::Vector {
width,
@ -1186,7 +1281,7 @@ fn write_expression<'a, 'b>(
source_kind,
Cow::Owned(format!(
"{}vec{}",
map_scalar(kind, width, builder.manager)?.prefix,
map_scalar(kind, width, manager)?.prefix,
size as u32,
)),
),
@ -1213,7 +1308,7 @@ fn write_expression<'a, 'b>(
Cow::Owned(format!("{}({})", op, value_expr))
}
Expression::Derivative { axis, expr } => {
let expr = write_expression(&builder.expressions[expr], module, builder)?;
let expr = write_expression(&builder.expressions[expr], module, builder, manager)?;
Cow::Owned(format!(
"{}({})",
@ -1236,7 +1331,8 @@ fn write_expression<'a, 'b>(
.map::<Result<_, Error>, _>(|arg| write_expression(
&builder.expressions[*arg],
module,
builder
builder,
manager,
))
.collect::<Result<Vec<_>, _>>()?
.join(","),
@ -1245,41 +1341,42 @@ fn write_expression<'a, 'b>(
origin: crate::FunctionOrigin::External(ref name),
ref arguments,
} => match name.as_str() {
"cos" | "normalize" | "sin" => {
let expr = write_expression(&builder.expressions[arguments[0]], module, builder)?;
"cos" | "normalize" | "sin" | "length" | "abs" | "floor" | "inverse" => {
let expr =
write_expression(&builder.expressions[arguments[0]], module, builder, manager)?;
Cow::Owned(format!("{}({})", name, expr))
}
"fclamp" => {
let val = write_expression(&builder.expressions[arguments[0]], module, builder)?;
let min = write_expression(&builder.expressions[arguments[1]], module, builder)?;
let max = write_expression(&builder.expressions[arguments[2]], module, builder)?;
"fclamp" | "clamp" | "mix" | "smoothstep" => {
let x =
write_expression(&builder.expressions[arguments[0]], module, builder, manager)?;
let y =
write_expression(&builder.expressions[arguments[1]], module, builder, manager)?;
let a =
write_expression(&builder.expressions[arguments[2]], module, builder, manager)?;
Cow::Owned(format!("clamp({}, {}, {})", val, min, max))
let name = match name.as_str() {
"fclamp" => "clamp",
name => name,
};
Cow::Owned(format!("{}({}, {}, {})", name, x, y, a))
}
"atan2" => {
let x = write_expression(&builder.expressions[arguments[0]], module, builder)?;
let y = write_expression(&builder.expressions[arguments[1]], module, builder)?;
let x =
write_expression(&builder.expressions[arguments[0]], module, builder, manager)?;
let y =
write_expression(&builder.expressions[arguments[1]], module, builder, manager)?;
Cow::Owned(format!("atan({}, {})", y, x))
}
"distance" => {
let p0 = write_expression(&builder.expressions[arguments[0]], module, builder)?;
let p1 = write_expression(&builder.expressions[arguments[1]], module, builder)?;
"distance" | "dot" | "min" | "max" | "reflect" | "pow" | "step" | "cross" => {
let x =
write_expression(&builder.expressions[arguments[0]], module, builder, manager)?;
let y =
write_expression(&builder.expressions[arguments[1]], module, builder, manager)?;
Cow::Owned(format!("distance({}, {})", p0, p1))
}
"length" => {
let x = write_expression(&builder.expressions[arguments[0]], module, builder)?;
Cow::Owned(format!("length({})", x))
}
"mix" => {
let x = write_expression(&builder.expressions[arguments[0]], module, builder)?;
let y = write_expression(&builder.expressions[arguments[0]], module, builder)?;
let a = write_expression(&builder.expressions[arguments[0]], module, builder)?;
Cow::Owned(format!("mix({}, {}, {})", x, y, a))
Cow::Owned(format!("{}({}, {})", name, x, y))
}
other => {
return Err(Error::Custom(format!(
@ -1289,7 +1386,7 @@ fn write_expression<'a, 'b>(
}
},
Expression::ArrayLength(expr) => {
let base = write_expression(&builder.expressions[expr], module, builder)?;
let base = write_expression(&builder.expressions[expr], module, builder, manager)?;
Cow::Owned(format!("uint({}.length())", base))
}
})
@ -1299,6 +1396,7 @@ fn write_constant(
constant: &Constant,
module: &Module,
builder: &mut StatementBuilder<'_>,
manager: &mut FeaturesManager,
) -> Result<String, Error> {
Ok(match constant.inner {
ConstantInner::Sint(int) => int.to_string(),
@ -1316,9 +1414,10 @@ fn write_constant(
TypeInner::Array { .. } => write_type(
constant.ty,
&module.types,
&module.constants,
builder.structs,
None,
builder.manager
manager
)?,
_ =>
return Err(Error::Custom(format!(
@ -1326,15 +1425,21 @@ fn write_constant(
write_type(
constant.ty,
&module.types,
&module.constants,
builder.structs,
None,
builder.manager
manager
)?
))),
},
components
.iter()
.map(|component| write_constant(&module.constants[*component], module, builder,))
.map(|component| write_constant(
&module.constants[*component],
module,
builder,
manager
))
.collect::<Result<Vec<_>, _>>()?
.join(","),
),
@ -1390,6 +1495,7 @@ fn map_scalar(
fn write_type<'a>(
ty: Handle<Type>,
types: &Arena<Type>,
constants: &Arena<Constant>,
structs: &'a FastHashMap<Handle<Type>, String>,
block: Option<String>,
manager: &mut FeaturesManager,
@ -1417,7 +1523,9 @@ fn write_type<'a>(
rows as u8
))
}
TypeInner::Pointer { base, .. } => write_type(base, types, structs, None, manager)?,
TypeInner::Pointer { base, .. } => {
write_type(base, types, constants, structs, None, manager)?
}
TypeInner::Array { base, size, .. } => {
if let TypeInner::Array { .. } = types[base].inner {
manager.request(Features::ARRAY_OF_ARRAYS)
@ -1425,8 +1533,8 @@ fn write_type<'a>(
Cow::Owned(format!(
"{}[{}]",
write_type(base, types, structs, None, manager)?,
write_array_size(size)?
write_type(base, types, constants, structs, None, manager)?,
write_array_size(size, constants)?
))
}
TypeInner::Struct { ref members } => {
@ -1438,7 +1546,7 @@ fn write_type<'a>(
writeln!(
&mut out,
"\t{} {};",
write_type(member.ty, types, structs, None, manager)?,
write_type(member.ty, types, constants, structs, None, manager)?,
member
.name
.clone()
@ -1500,22 +1608,24 @@ fn write_storage_class(
manager: &mut FeaturesManager,
) -> Result<&'static str, Error> {
Ok(match class {
StorageClass::Constant => "",
StorageClass::Function => "",
StorageClass::Input => "in ",
StorageClass::Output => "out ",
StorageClass::Private => "",
StorageClass::StorageBuffer => {
StorageClass::Storage => {
manager.request(Features::BUFFER_STORAGE);
"buffer "
}
StorageClass::Uniform => "uniform ",
StorageClass::Handle => "uniform ",
StorageClass::WorkGroup => {
manager.request(Features::COMPUTE_SHADER);
"shared "
}
StorageClass::PushConstant => {
manager.request(Features::PUSH_CONSTANT);
""
}
})
}
@ -1534,9 +1644,12 @@ fn write_interpolation(interpolation: Interpolation) -> Result<&'static str, Err
})
}
fn write_array_size(size: ArraySize) -> Result<String, Error> {
fn write_array_size(size: ArraySize, constants: &Arena<Constant>) -> Result<String, Error> {
Ok(match size {
ArraySize::Static(size) => size.to_string(),
ArraySize::Constant(const_handle) => match constants[const_handle].inner {
ConstantInner::Uint(size) => size.to_string(),
_ => unreachable!(),
},
ArraySize::Dynamic => String::from(""),
})
}
@ -1598,7 +1711,14 @@ fn write_struct(
writeln!(
&mut tmp,
"\t{} {};",
write_type(member.ty, &module.types, &structs, None, manager)?,
write_type(
member.ty,
&module.types,
&module.constants,
&structs,
None,
manager
)?,
member
.name
.clone()
@ -1794,6 +1914,7 @@ fn collect_texture_mapping<'a>(
for func in functions {
let mut interface = Interface {
expressions: &func.expressions,
local_variables: &func.local_variables,
visitor: TextureMappingVisitor {
expressions: &func.expressions,
map: &mut mappings,

File diff suppressed because it is too large Load diff

View file

@ -235,6 +235,13 @@ pub(super) fn instruction_type_sampler(id: Word) -> Instruction {
instruction
}
pub(super) fn instruction_type_sampled_image(id: Word, image_type_id: Word) -> Instruction {
let mut instruction = Instruction::new(Op::TypeSampledImage);
instruction.set_result(id);
instruction.add_operand(image_type_id);
instruction
}
pub(super) fn instruction_type_array(
id: Word,
element_type_id: Word,
@ -399,6 +406,24 @@ pub(super) fn instruction_store(
instruction
}
pub(super) fn instruction_access_chain(
result_type_id: Word,
id: Word,
base_id: Word,
index_ids: &[Word],
) -> Instruction {
let mut instruction = Instruction::new(Op::AccessChain);
instruction.set_type(result_type_id);
instruction.set_result(id);
instruction.add_operand(base_id);
for index_id in index_ids {
instruction.add_operand(*index_id);
}
instruction
}
//
// Function Instructions
//
@ -449,6 +474,33 @@ pub(super) fn instruction_function_call(
//
// Image Instructions
//
pub(super) fn instruction_sampled_image(
result_type_id: Word,
id: Word,
image: Word,
sampler: Word,
) -> Instruction {
let mut instruction = Instruction::new(Op::SampledImage);
instruction.set_type(result_type_id);
instruction.set_result(id);
instruction.add_operand(image);
instruction.add_operand(sampler);
instruction
}
pub(super) fn instruction_image_sample_implicit_lod(
result_type_id: Word,
id: Word,
sampled_image: Word,
coordinates: Word,
) -> Instruction {
let mut instruction = Instruction::new(Op::ImageSampleImplicitLod);
instruction.set_type(result_type_id);
instruction.set_result(id);
instruction.add_operand(sampled_image);
instruction.add_operand(coordinates);
instruction
}
//
// Conversion Instructions

View file

@ -24,6 +24,10 @@ impl PhysicalLayout {
sink.extend(iter::once(self.bound));
sink.extend(iter::once(self.instruction_schema));
}
pub(super) fn supports_storage_buffers(&self) -> bool {
self.version >= 0x10300
}
}
impl LogicalLayout {

File diff suppressed because it is too large Load diff

View file

@ -42,8 +42,8 @@ impl Program {
pub fn binary_expr(
&mut self,
op: BinaryOperator,
left: ExpressionRule,
right: ExpressionRule,
left: &ExpressionRule,
right: &ExpressionRule,
) -> ExpressionRule {
ExpressionRule::from_expression(self.context.expressions.append(Expression::Binary {
op,
@ -57,13 +57,13 @@ impl Program {
handle: Handle<crate::Expression>,
) -> Result<&crate::TypeInner, ErrorKind> {
let functions = Arena::new(); //TODO
let parameter_types: Vec<Handle<Type>> = vec![]; //TODO
let arguments = Vec::new(); //TODO
let resolve_ctx = ResolveContext {
constants: &self.module.constants,
global_vars: &self.module.global_variables,
local_vars: &self.context.local_variables,
functions: &functions,
parameter_types: &parameter_types,
arguments: &arguments,
};
match self.context.typifier.grow(
handle,
@ -138,6 +138,7 @@ impl Context {
pub struct ExpressionRule {
pub expression: Handle<Expression>,
pub statements: Vec<Statement>,
pub sampler: Option<Handle<Expression>>,
}
impl ExpressionRule {
@ -145,6 +146,7 @@ impl ExpressionRule {
ExpressionRule {
expression,
statements: vec![],
sampler: None,
}
}
}
@ -166,12 +168,11 @@ pub struct VarDeclaration {
#[derive(Debug)]
pub enum FunctionCallKind {
TypeConstructor(Handle<Type>),
Function(Handle<Expression>),
Function(String),
}
#[derive(Debug)]
pub struct FunctionCall {
pub kind: FunctionCallKind,
pub args: Vec<Handle<Expression>>,
pub statements: Vec<Statement>,
pub args: Vec<ExpressionRule>,
}

View file

@ -21,6 +21,7 @@ pub enum ErrorKind {
VariableNotAvailable(String),
ExpectedConstant,
SemanticError(&'static str),
PreprocessorError(String),
}
impl fmt::Display for ErrorKind {
@ -53,6 +54,7 @@ impl fmt::Display for ErrorKind {
}
ErrorKind::ExpectedConstant => write!(f, "Expected constant"),
ErrorKind::SemanticError(msg) => write!(f, "Semantic error: {}", msg),
ErrorKind::PreprocessorError(val) => write!(f, "Preprocessor error: {}", val),
}
}
}

View file

@ -1,5 +1,5 @@
use super::parser::Token;
use super::{token::TokenMetadata, types::parse_type};
use super::{preprocess::LinePreProcessor, token::TokenMetadata, types::parse_type};
use std::{iter::Enumerate, str::Lines};
fn _consume_str<'a>(input: &'a str, what: &str) -> Option<&'a str> {
@ -23,6 +23,7 @@ pub struct Lexer<'a> {
line: usize,
offset: usize,
inside_comment: bool,
pub pp: LinePreProcessor,
}
impl<'a> Lexer<'a> {
@ -139,6 +140,16 @@ impl<'a> Lexer<'a> {
"break" => Some(Token::Break(meta)),
"return" => Some(Token::Return(meta)),
"discard" => Some(Token::Discard(meta)),
// selection statements
"if" => Some(Token::If(meta)),
"else" => Some(Token::Else(meta)),
"switch" => Some(Token::Switch(meta)),
"case" => Some(Token::Case(meta)),
"default" => Some(Token::Default(meta)),
// iteration statements
"while" => Some(Token::While(meta)),
"do" => Some(Token::Do(meta)),
"for" => Some(Token::For(meta)),
// types
"void" => Some(Token::Void(meta)),
word => {
@ -283,25 +294,42 @@ impl<'a> Lexer<'a> {
}
pub fn new(input: &'a str) -> Self {
let mut lines = input.lines().enumerate();
let (line, input) = lines.next().unwrap_or((0, ""));
let mut input = String::from(input);
while input.ends_with('\\') {
if let Some((_, next)) = lines.next() {
input.pop();
input.push_str(next);
} else {
break;
}
}
Lexer {
lines,
input,
line,
let mut lexer = Lexer {
lines: input.lines().enumerate(),
input: "".to_string(),
line: 0,
offset: 0,
inside_comment: false,
pp: LinePreProcessor::new(),
};
lexer.next_line();
lexer
}
fn next_line(&mut self) -> bool {
if let Some((line, input)) = self.lines.next() {
let mut input = String::from(input);
while input.ends_with('\\') {
if let Some((_, next)) = self.lines.next() {
input.pop();
input.push_str(next);
} else {
break;
}
}
if let Ok(processed) = self.pp.process_line(&input) {
self.input = processed.unwrap_or_default();
self.line = line;
self.offset = 0;
true
} else {
//TODO: handle preprocessor error
false
}
} else {
false
}
}
@ -331,22 +359,9 @@ impl<'a> Lexer<'a> {
self.next()
}
} else {
let (line, input) = self.lines.next()?;
let mut input = String::from(input);
while input.ends_with('\\') {
if let Some((_, next)) = self.lines.next() {
input.pop();
input.push_str(next);
} else {
break;
}
if !self.next_line() {
return None;
}
self.input = input;
self.line = line;
self.offset = 0;
self.next()
}
}

View file

@ -1,9 +1,13 @@
use crate::{Module, ShaderStage};
use crate::{FastHashMap, Module, ShaderStage};
mod lex;
#[cfg(test)]
mod lex_tests;
mod preprocess;
#[cfg(test)]
mod preprocess_tests;
mod ast;
use ast::Program;
@ -17,11 +21,17 @@ mod token;
mod types;
mod variables;
pub fn parse_str(source: &str, entry: &str, stage: ShaderStage) -> Result<Module, ParseError> {
log::debug!("------ GLSL-pomelo ------");
pub fn parse_str(
source: &str,
entry: &str,
stage: ShaderStage,
defines: FastHashMap<String, String>,
) -> Result<Module, ParseError> {
let mut program = Program::new(stage, entry);
let lex = Lexer::new(source);
let mut lex = Lexer::new(source);
lex.pp.defines = defines;
let mut parser = parser::Parser::new(&mut program);
for token in lex {

View file

@ -6,9 +6,9 @@ pomelo! {
%include {
use super::super::{error::ErrorKind, token::*, ast::*};
use crate::{proc::Typifier, Arena, BinaryOperator, Binding, Block, Constant,
ConstantInner, EntryPoint, Expression, Function, GlobalVariable, Handle, Interpolation,
LocalVariable, MemberOrigin, ScalarKind, Statement, StorageAccess,
StorageClass, StructMember, Type, TypeInner};
ConstantInner, EntryPoint, Expression, FallThrough, FastHashMap, Function, GlobalVariable, Handle, Interpolation,
LocalVariable, MemberOrigin, SampleLevel, ScalarKind, Statement, StorageAccess,
StorageClass, StructMember, Type, TypeInner, UnaryOperator};
}
%token #[derive(Debug)] #[cfg_attr(test, derive(PartialEq))] pub enum Token {};
%parser pub struct Parser<'a> {};
@ -55,6 +55,13 @@ pomelo! {
%type expression_statement Statement;
%type declaration_statement Statement;
%type jump_statement Statement;
%type iteration_statement Statement;
%type selection_statement Statement;
%type switch_statement_list Vec<(Option<i32>, Block, Option<FallThrough>)>;
%type switch_statement (Option<i32>, Block, Option<FallThrough>);
%type for_init_statement Statement;
%type for_rest_statement (Option<ExpressionRule>, Option<ExpressionRule>);
%type condition_opt Option<ExpressionRule>;
// expressions
%type unary_expression ExpressionRule;
@ -90,7 +97,7 @@ pomelo! {
%type initializer ExpressionRule;
// decalartions
// declarations
%type declaration VarDeclaration;
%type init_declarator_list VarDeclaration;
%type single_declaration VarDeclaration;
@ -115,6 +122,9 @@ pomelo! {
%type TypeName Type;
// precedence
%right Else;
root ::= version_pragma translation_unit;
version_pragma ::= Version IntConstant(V) Identifier?(P) {
match V.1 {
@ -140,9 +150,7 @@ pomelo! {
let var = extra.lookup_variable(&v.1)?;
match var {
Some(expression) => {
ExpressionRule::from_expression(
expression
)
ExpressionRule::from_expression(expression)
},
None => {
return Err(ErrorKind::UnknownVariable(v.0, v.1));
@ -220,7 +228,7 @@ pomelo! {
postfix_expression ::= postfix_expression(e) Dot Identifier(i) /* FieldSelection in spec */ {
//TODO: how will this work as l-value?
let expression = extra.field_selection(e.expression, &*i.1, i.0)?;
ExpressionRule { expression, statements: e.statements }
ExpressionRule { expression, statements: e.statements, sampler: None }
}
postfix_expression ::= postfix_expression(pe) IncOp {
//TODO
@ -234,17 +242,49 @@ pomelo! {
integer_expression ::= expression;
function_call ::= function_call_or_method(fc) {
if let FunctionCallKind::TypeConstructor(ty) = fc.kind {
let h = extra.context.expressions.append(Expression::Compose{
ty,
components: fc.args,
});
ExpressionRule{
expression: h,
statements: fc.statements,
match fc.kind {
FunctionCallKind::TypeConstructor(ty) => {
let h = extra.context.expressions.append(Expression::Compose {
ty,
components: fc.args.iter().map(|a| a.expression).collect(),
});
ExpressionRule {
expression: h,
statements: fc.args.into_iter().map(|a| a.statements).flatten().collect(),
sampler: None
}
}
FunctionCallKind::Function(name) => {
match name.as_str() {
"sampler2D" => {
//TODO: check args len
ExpressionRule{
expression: fc.args[0].expression,
sampler: Some(fc.args[1].expression),
statements: fc.args.into_iter().map(|a| a.statements).flatten().collect(),
}
}
"texture" => {
//TODO: check args len
if let Some(sampler) = fc.args[0].sampler {
ExpressionRule{
expression: extra.context.expressions.append(Expression::ImageSample {
image: fc.args[0].expression,
sampler,
coordinate: fc.args[1].expression,
level: SampleLevel::Auto,
depth_ref: None,
}),
sampler: None,
statements: fc.args.into_iter().map(|a| a.statements).flatten().collect(),
}
} else {
return Err(ErrorKind::SemanticError("Bad call to texture"));
}
}
_ => { return Err(ErrorKind::NotImplemented("Function call")); }
}
}
} else {
return Err(ErrorKind::NotImplemented("Function call"));
}
}
function_call_or_method ::= function_call_generic;
@ -259,26 +299,22 @@ pomelo! {
}
function_call_header_no_parameters ::= function_call_header;
function_call_header_with_parameters ::= function_call_header(mut h) assignment_expression(ae) {
h.args.push(ae.expression);
h.statements.extend(ae.statements);
h.args.push(ae);
h
}
function_call_header_with_parameters ::= function_call_header_with_parameters(mut h) Comma assignment_expression(ae) {
h.args.push(ae.expression);
h.statements.extend(ae.statements);
h.args.push(ae);
h
}
function_call_header ::= function_identifier(i) LeftParen {
FunctionCall {
kind: i,
args: vec![],
statements: vec![],
}
}
// Grammar Note: Constructors look like functions, but lexical analysis recognized most of them as
// keywords. They are now recognized through “type_specifier”.
// Methods (.length), subroutine array calls, and identifiers are recognized through postfix_expression.
function_identifier ::= type_specifier(t) {
if let Some(ty) = t {
FunctionCallKind::TypeConstructor(ty)
@ -286,10 +322,19 @@ pomelo! {
return Err(ErrorKind::NotImplemented("bad type ctor"))
}
}
function_identifier ::= postfix_expression(e) {
FunctionCallKind::Function(e.expression)
//TODO
// Methods (.length), subroutine array calls, and identifiers are recognized through postfix_expression.
// function_identifier ::= postfix_expression(e) {
// FunctionCallKind::Function(e.expression)
// }
// Simplification of above
function_identifier ::= Identifier(i) {
FunctionCallKind::Function(i.1)
}
unary_expression ::= postfix_expression;
unary_expression ::= IncOp unary_expression {
@ -311,74 +356,76 @@ pomelo! {
unary_operator ::= Tilde;
multiplicative_expression ::= unary_expression;
multiplicative_expression ::= multiplicative_expression(left) Star unary_expression(right) {
extra.binary_expr(BinaryOperator::Multiply, left, right)
extra.binary_expr(BinaryOperator::Multiply, &left, &right)
}
multiplicative_expression ::= multiplicative_expression(left) Slash unary_expression(right) {
extra.binary_expr(BinaryOperator::Divide, left, right)
extra.binary_expr(BinaryOperator::Divide, &left, &right)
}
multiplicative_expression ::= multiplicative_expression(left) Percent unary_expression(right) {
extra.binary_expr(BinaryOperator::Modulo, left, right)
extra.binary_expr(BinaryOperator::Modulo, &left, &right)
}
additive_expression ::= multiplicative_expression;
additive_expression ::= additive_expression(left) Plus multiplicative_expression(right) {
extra.binary_expr(BinaryOperator::Add, left, right)
extra.binary_expr(BinaryOperator::Add, &left, &right)
}
additive_expression ::= additive_expression(left) Dash multiplicative_expression(right) {
extra.binary_expr(BinaryOperator::Subtract, left, right)
extra.binary_expr(BinaryOperator::Subtract, &left, &right)
}
shift_expression ::= additive_expression;
shift_expression ::= shift_expression(left) LeftOp additive_expression(right) {
extra.binary_expr(BinaryOperator::ShiftLeftLogical, left, right)
extra.binary_expr(BinaryOperator::ShiftLeft, &left, &right)
}
shift_expression ::= shift_expression(left) RightOp additive_expression(right) {
//TODO: when to use ShiftRightArithmetic
extra.binary_expr(BinaryOperator::ShiftRightLogical, left, right)
extra.binary_expr(BinaryOperator::ShiftRight, &left, &right)
}
relational_expression ::= shift_expression;
relational_expression ::= relational_expression(left) LeftAngle shift_expression(right) {
extra.binary_expr(BinaryOperator::Less, left, right)
extra.binary_expr(BinaryOperator::Less, &left, &right)
}
relational_expression ::= relational_expression(left) RightAngle shift_expression(right) {
extra.binary_expr(BinaryOperator::Greater, left, right)
extra.binary_expr(BinaryOperator::Greater, &left, &right)
}
relational_expression ::= relational_expression(left) LeOp shift_expression(right) {
extra.binary_expr(BinaryOperator::LessEqual, left, right)
extra.binary_expr(BinaryOperator::LessEqual, &left, &right)
}
relational_expression ::= relational_expression(left) GeOp shift_expression(right) {
extra.binary_expr(BinaryOperator::GreaterEqual, left, right)
extra.binary_expr(BinaryOperator::GreaterEqual, &left, &right)
}
equality_expression ::= relational_expression;
equality_expression ::= equality_expression(left) EqOp relational_expression(right) {
extra.binary_expr(BinaryOperator::Equal, left, right)
extra.binary_expr(BinaryOperator::Equal, &left, &right)
}
equality_expression ::= equality_expression(left) NeOp relational_expression(right) {
extra.binary_expr(BinaryOperator::NotEqual, left, right)
extra.binary_expr(BinaryOperator::NotEqual, &left, &right)
}
and_expression ::= equality_expression;
and_expression ::= and_expression(left) Ampersand equality_expression(right) {
extra.binary_expr(BinaryOperator::And, left, right)
extra.binary_expr(BinaryOperator::And, &left, &right)
}
exclusive_or_expression ::= and_expression;
exclusive_or_expression ::= exclusive_or_expression(left) Caret and_expression(right) {
extra.binary_expr(BinaryOperator::ExclusiveOr, left, right)
extra.binary_expr(BinaryOperator::ExclusiveOr, &left, &right)
}
inclusive_or_expression ::= exclusive_or_expression;
inclusive_or_expression ::= inclusive_or_expression(left) VerticalBar exclusive_or_expression(right) {
extra.binary_expr(BinaryOperator::InclusiveOr, left, right)
extra.binary_expr(BinaryOperator::InclusiveOr, &left, &right)
}
logical_and_expression ::= inclusive_or_expression;
logical_and_expression ::= logical_and_expression(left) AndOp inclusive_or_expression(right) {
extra.binary_expr(BinaryOperator::LogicalAnd, left, right)
extra.binary_expr(BinaryOperator::LogicalAnd, &left, &right)
}
logical_xor_expression ::= logical_and_expression;
logical_xor_expression ::= logical_xor_expression(left) XorOp logical_and_expression(right) {
return Err(ErrorKind::NotImplemented("logical xor"))
//TODO: naga doesn't have BinaryOperator::LogicalXor
// extra.context.expressions.append(Expression::Binary{op: BinaryOperator::LogicalXor, left, right})
let exp1 = extra.binary_expr(BinaryOperator::LogicalOr, &left, &right);
let exp2 = {
let tmp = extra.binary_expr(BinaryOperator::LogicalAnd, &left, &right).expression;
ExpressionRule::from_expression(extra.context.expressions.append(Expression::Unary { op: UnaryOperator::Not, expr: tmp }))
};
extra.binary_expr(BinaryOperator::LogicalAnd, &exp1, &exp2)
}
logical_or_expression ::= logical_xor_expression;
logical_or_expression ::= logical_or_expression(left) OrOp logical_xor_expression(right) {
extra.binary_expr(BinaryOperator::LogicalOr, left, right)
extra.binary_expr(BinaryOperator::LogicalOr, &left, &right)
}
conditional_expression ::= logical_or_expression;
@ -389,17 +436,29 @@ pomelo! {
assignment_expression ::= conditional_expression;
assignment_expression ::= unary_expression(mut pointer) assignment_operator(op) assignment_expression(value) {
pointer.statements.extend(value.statements);
match op {
BinaryOperator::Equal => {
pointer.statements.extend(value.statements);
pointer.statements.push(Statement::Store{
pointer: pointer.expression,
value: value.expression
});
pointer
},
//TODO: op != Equal
_ => {return Err(ErrorKind::NotImplemented("assign op"))}
_ => {
let h = extra.context.expressions.append(
Expression::Binary{
op,
left: pointer.expression,
right: value.expression,
}
);
pointer.statements.push(Statement::Store{
pointer: pointer.expression,
value: h,
});
pointer
}
}
}
@ -422,10 +481,10 @@ pomelo! {
BinaryOperator::Subtract
}
assignment_operator ::= LeftAssign {
BinaryOperator::ShiftLeftLogical
BinaryOperator::ShiftLeft
}
assignment_operator ::= RightAssign {
BinaryOperator::ShiftRightLogical
BinaryOperator::ShiftRight
}
assignment_operator ::= AndAssign {
BinaryOperator::And
@ -440,9 +499,10 @@ pomelo! {
expression ::= assignment_expression;
expression ::= expression(e) Comma assignment_expression(mut ae) {
ae.statements.extend(e.statements);
ExpressionRule{
ExpressionRule {
expression: e.expression,
statements: ae.statements,
sampler: None,
}
}
@ -462,7 +522,7 @@ pomelo! {
declaration ::= type_qualifier(t) Identifier(i) LeftBrace
struct_declaration_list(sdl) RightBrace Semicolon {
VarDeclaration{
VarDeclaration {
type_qualifiers: t,
ids_initializers: vec![(None, None)],
ty: extra.module.types.fetch_or_append(Type{
@ -476,7 +536,7 @@ pomelo! {
declaration ::= type_qualifier(t) Identifier(i1) LeftBrace
struct_declaration_list(sdl) RightBrace Identifier(i2) Semicolon {
VarDeclaration{
VarDeclaration {
type_qualifiers: t,
ids_initializers: vec![(Some(i2.1), None)],
ty: extra.module.types.fetch_or_append(Type{
@ -506,7 +566,7 @@ pomelo! {
single_declaration ::= fully_specified_type(t) {
let ty = t.1.ok_or(ErrorKind::SemanticError("Empty type for declaration"))?;
VarDeclaration{
VarDeclaration {
type_qualifiers: t.0,
ids_initializers: vec![],
ty,
@ -515,7 +575,7 @@ pomelo! {
single_declaration ::= fully_specified_type(t) Identifier(i) {
let ty = t.1.ok_or(ErrorKind::SemanticError("Empty type for declaration"))?;
VarDeclaration{
VarDeclaration {
type_qualifiers: t.0,
ids_initializers: vec![(Some(i.1), None)],
ty,
@ -526,7 +586,7 @@ pomelo! {
single_declaration ::= fully_specified_type(t) Identifier(i) Equal initializer(init) {
let ty = t.1.ok_or(ErrorKind::SemanticError("Empty type for declaration"))?;
VarDeclaration{
VarDeclaration {
type_qualifiers: t.0,
ids_initializers: vec![(Some(i.1), Some(init))],
ty,
@ -597,9 +657,7 @@ pomelo! {
// single_type_qualifier ::= invariant_qualifier;
// single_type_qualifier ::= precise_qualifier;
storage_qualifier ::= Const {
StorageClass::Constant
}
// storage_qualifier ::= Const
// storage_qualifier ::= InOut;
storage_qualifier ::= In {
StorageClass::Input
@ -658,7 +716,7 @@ pomelo! {
struct_declaration ::= type_specifier(t) struct_declarator_list(sdl) Semicolon {
if let Some(ty) = t {
sdl.iter().map(|name| StructMember{
sdl.iter().map(|name| StructMember {
name: Some(name.clone()),
origin: MemberOrigin::Empty,
ty,
@ -702,18 +760,33 @@ pomelo! {
return Err(ErrorKind::VariableAlreadyDeclared(id))
}
}
let mut init_exp: Option<Handle<Expression>> = None;
let localVar = extra.context.local_variables.append(
LocalVariable {
name: Some(id.clone()),
ty: d.ty,
init: initializer.map(|i| {
statements.extend(i.statements);
i.expression
}),
if let Expression::Constant(constant) = extra.context.expressions[i.expression] {
Some(constant)
} else {
init_exp = Some(i.expression);
None
}
}).flatten(),
}
);
let exp = extra.context.expressions.append(Expression::LocalVariable(localVar));
extra.context.add_local_var(id, exp);
if let Some(value) = init_exp {
statements.push(
Statement::Store {
pointer: exp,
value,
}
);
}
}
match statements.len() {
1 => statements.remove(0),
@ -727,14 +800,138 @@ pomelo! {
}
statement ::= simple_statement;
// Grammar Note: labeled statements for SWITCH only; 'goto' is not supported.
simple_statement ::= declaration_statement;
simple_statement ::= expression_statement;
//simple_statement ::= selection_statement;
//simple_statement ::= switch_statement;
//simple_statement ::= case_label;
//simple_statement ::= iteration_statement;
simple_statement ::= selection_statement;
simple_statement ::= jump_statement;
simple_statement ::= iteration_statement;
selection_statement ::= If LeftParen expression(e) RightParen statement(s1) Else statement(s2) {
Statement::If {
condition: e.expression,
accept: vec![s1],
reject: vec![s2],
}
}
selection_statement ::= If LeftParen expression(e) RightParen statement(s) [Else] {
Statement::If {
condition: e.expression,
accept: vec![s],
reject: vec![],
}
}
selection_statement ::= Switch LeftParen expression(e) RightParen LeftBrace switch_statement_list(ls) RightBrace {
let mut default = Vec::new();
let mut cases = FastHashMap::default();
for (v, s, ft) in ls {
if let Some(v) = v {
cases.insert(v, (s, ft));
} else {
default.extend_from_slice(&s);
}
}
Statement::Switch {
selector: e.expression,
cases,
default,
}
}
switch_statement_list ::= {
vec![]
}
switch_statement_list ::= switch_statement_list(mut ssl) switch_statement((v, sl, ft)) {
ssl.push((v, sl, ft));
ssl
}
switch_statement ::= Case IntConstant(v) Colon statement_list(sl) {
let fallthrough = match sl.last() {
Some(Statement::Break) => None,
_ => Some(FallThrough),
};
(Some(v.1 as i32), sl, fallthrough)
}
switch_statement ::= Default Colon statement_list(sl) {
let fallthrough = match sl.last() {
Some(Statement::Break) => Some(FallThrough),
_ => None,
};
(None, sl, fallthrough)
}
iteration_statement ::= While LeftParen expression(e) RightParen compound_statement_no_new_scope(sl) {
let mut body = Vec::with_capacity(sl.len() + 1);
body.push(
Statement::If {
condition: e.expression,
accept: vec![Statement::Break],
reject: vec![],
}
);
body.extend_from_slice(&sl);
Statement::Loop {
body,
continuing: vec![],
}
}
iteration_statement ::= Do compound_statement(sl) While LeftParen expression(e) RightParen {
let mut body = sl;
body.push(
Statement::If {
condition: e.expression,
accept: vec![Statement::Break],
reject: vec![],
}
);
Statement::Loop {
body,
continuing: vec![],
}
}
iteration_statement ::= For LeftParen for_init_statement(s_init) for_rest_statement((cond_e, loop_e)) RightParen compound_statement_no_new_scope(sl) {
let mut body = Vec::with_capacity(sl.len() + 2);
if let Some(cond_e) = cond_e {
body.push(
Statement::If {
condition: cond_e.expression,
accept: vec![Statement::Break],
reject: vec![],
}
);
}
body.extend_from_slice(&sl);
if let Some(loop_e) = loop_e {
body.extend_from_slice(&loop_e.statements);
}
Statement::Block(vec![
s_init,
Statement::Loop {
body,
continuing: vec![],
}
])
}
for_init_statement ::= expression_statement;
for_init_statement ::= declaration_statement;
for_rest_statement ::= condition_opt(c) Semicolon {
(c, None)
}
for_rest_statement ::= condition_opt(c) Semicolon expression(e) {
(c, Some(e))
}
condition_opt ::= {
None
}
condition_opt ::= conditional_expression(c) {
Some(c)
}
compound_statement ::= LeftBrace RightBrace {
vec![]
@ -810,7 +1007,7 @@ pomelo! {
function_header ::= fully_specified_type(t) Identifier(n) LeftParen {
Function {
name: Some(n.1),
parameter_types: vec![],
arguments: vec![],
return_type: t.1,
global_usage: vec![],
local_variables: Arena::<LocalVariable>::new(),
@ -826,7 +1023,7 @@ pomelo! {
Statement::Break
}
jump_statement ::= Return Semicolon {
Statement::Return{ value: None }
Statement::Return { value: None }
}
jump_statement ::= Return expression(mut e) Semicolon {
let ret = Statement::Return{ value: Some(e.expression) };
@ -884,6 +1081,7 @@ pomelo! {
class,
binding: binding.clone(),
ty: d.ty,
init: None,
interpolation,
storage_access: StorageAccess::empty(), //TODO
},
@ -894,12 +1092,17 @@ pomelo! {
}
}
function_definition ::= function_prototype(mut f) compound_statement_no_new_scope(cs) {
function_definition ::= function_prototype(mut f) compound_statement_no_new_scope(mut cs) {
std::mem::swap(&mut f.expressions, &mut extra.context.expressions);
std::mem::swap(&mut f.local_variables, &mut extra.context.local_variables);
extra.context.clear_scopes();
extra.context.lookup_global_var_exps.clear();
extra.context.typifier = Typifier::new();
// make sure function ends with return
match cs.last() {
Some(Statement::Return {..}) => {}
_ => {cs.push(Statement::Return { value:None });}
}
f.body = cs;
f.fill_global_use(&extra.module.global_variables);
f

View file

@ -78,3 +78,105 @@ fn version() {
"(450, Core)"
);
}
#[test]
fn control_flow() {
let _program = parse_program(
r#"
# version 450
void main() {
if (true) {
return 1;
} else {
return 2;
}
}
"#,
ShaderStage::Vertex,
)
.unwrap();
let _program = parse_program(
r#"
# version 450
void main() {
if (true) {
return 1;
}
}
"#,
ShaderStage::Vertex,
)
.unwrap();
let _program = parse_program(
r#"
# version 450
void main() {
int x;
int y = 3;
switch (5) {
case 2:
x = 2;
case 5:
x = 5;
y = 2;
break;
default:
x = 0;
}
}
"#,
ShaderStage::Vertex,
)
.unwrap();
let _program = parse_program(
r#"
# version 450
void main() {
int x = 0;
while(x < 5) {
x = x + 1;
}
do {
x = x - 1;
} while(x >= 4)
}
"#,
ShaderStage::Vertex,
)
.unwrap();
let _program = parse_program(
r#"
# version 450
void main() {
int x = 0;
for(int i = 0; i < 10;) {
x = x + 2;
}
return x;
}
"#,
ShaderStage::Vertex,
)
.unwrap();
}
#[test]
fn textures() {
let _program = parse_program(
r#"
#version 450
layout(location = 0) in vec2 v_uv;
layout(location = 0) out vec4 o_color;
layout(set = 1, binding = 1) uniform texture2D tex;
layout(set = 1, binding = 2) uniform sampler tex_sampler;
void main() {
o_color = texture(sampler2D(tex, tex_sampler), v_uv);
}
"#,
ShaderStage::Fragment,
)
.unwrap();
}

View file

@ -0,0 +1,152 @@
use crate::FastHashMap;
use thiserror::Error;
#[derive(Clone, Debug, Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum Error {
#[error("unmatched else")]
UnmatchedElse,
#[error("unmatched endif")]
UnmatchedEndif,
#[error("missing macro name")]
MissingMacro,
}
#[derive(Clone, Debug)]
pub struct IfState {
true_branch: bool,
else_seen: bool,
}
#[derive(Clone, Debug)]
pub struct LinePreProcessor {
pub defines: FastHashMap<String, String>,
if_stack: Vec<IfState>,
inside_comment: bool,
in_preprocess: bool,
}
impl LinePreProcessor {
pub fn new() -> Self {
LinePreProcessor {
defines: FastHashMap::default(),
if_stack: vec![],
inside_comment: false,
in_preprocess: false,
}
}
fn subst_defines(&self, input: &str) -> String {
//TODO: don't subst in commments, strings literals?
self.defines
.iter()
.fold(input.to_string(), |acc, (k, v)| acc.replace(k, v))
}
pub fn process_line(&mut self, line: &str) -> Result<Option<String>, Error> {
let mut skip = !self.if_stack.last().map(|i| i.true_branch).unwrap_or(true);
let mut inside_comment = self.inside_comment;
let mut in_preprocess = inside_comment && self.in_preprocess;
// single-line comment
let mut processed = line;
if let Some(pos) = line.find("//") {
processed = line.split_at(pos).0;
}
// multi-line comment
let mut processed_string: String;
loop {
if inside_comment {
if let Some(pos) = processed.find("*/") {
processed = processed.split_at(pos + 2).1;
inside_comment = false;
self.inside_comment = false;
continue;
}
} else if let Some(pos) = processed.find("/*") {
if let Some(end_pos) = processed[pos + 2..].find("*/") {
// comment ends during this line
processed_string = processed.to_string();
processed_string.replace_range(pos..pos + end_pos + 4, "");
processed = &processed_string;
} else {
processed = processed.split_at(pos).0;
inside_comment = true;
}
continue;
}
break;
}
// strip leading whitespace
processed = processed.trim_start();
if processed.starts_with('#') && !self.inside_comment {
let mut iter = processed[1..]
.trim_start()
.splitn(2, |c: char| c.is_whitespace());
if let Some(directive) = iter.next() {
skip = true;
in_preprocess = true;
match directive {
"version" => {
skip = false;
}
"define" => {
let rest = iter.next().ok_or(Error::MissingMacro)?;
let pos = rest
.find(|c: char| !c.is_ascii_alphanumeric() && c != '_' && c != '(')
.unwrap_or_else(|| rest.len());
let (key, mut value) = rest.split_at(pos);
value = value.trim();
self.defines.insert(key.into(), self.subst_defines(value));
}
"undef" => {
let rest = iter.next().ok_or(Error::MissingMacro)?;
let key = rest.trim();
self.defines.remove(key);
}
"ifdef" => {
let rest = iter.next().ok_or(Error::MissingMacro)?;
let key = rest.trim();
self.if_stack.push(IfState {
true_branch: self.defines.contains_key(key),
else_seen: false,
});
}
"ifndef" => {
let rest = iter.next().ok_or(Error::MissingMacro)?;
let key = rest.trim();
self.if_stack.push(IfState {
true_branch: !self.defines.contains_key(key),
else_seen: false,
});
}
"else" => {
let if_state = self.if_stack.last_mut().ok_or(Error::UnmatchedElse)?;
if !if_state.else_seen {
// this is first else
if_state.true_branch = !if_state.true_branch;
if_state.else_seen = true;
} else {
return Err(Error::UnmatchedElse);
}
}
"endif" => {
self.if_stack.pop().ok_or(Error::UnmatchedEndif)?;
}
_ => {}
}
}
}
let res = if !skip && !self.inside_comment {
Ok(Some(self.subst_defines(&line)))
} else {
Ok(if in_preprocess && !self.in_preprocess {
Some("".to_string())
} else {
None
})
};
self.in_preprocess = in_preprocess || skip;
self.inside_comment = inside_comment;
res
}
}

View file

@ -0,0 +1,218 @@
use super::preprocess::{Error, LinePreProcessor};
use std::{iter::Enumerate, str::Lines};
#[derive(Clone, Debug)]
pub struct PreProcessor<'a> {
lines: Enumerate<Lines<'a>>,
input: String,
line: usize,
offset: usize,
line_pp: LinePreProcessor,
}
impl<'a> PreProcessor<'a> {
pub fn new(input: &'a str) -> Self {
let mut lexer = PreProcessor {
lines: input.lines().enumerate(),
input: "".to_string(),
line: 0,
offset: 0,
line_pp: LinePreProcessor::new(),
};
lexer.next_line();
lexer
}
fn next_line(&mut self) -> bool {
if let Some((line, input)) = self.lines.next() {
let mut input = String::from(input);
while input.ends_with('\\') {
if let Some((_, next)) = self.lines.next() {
input.pop();
input.push_str(next);
} else {
break;
}
}
self.input = input;
self.line = line;
self.offset = 0;
true
} else {
false
}
}
pub fn process(&mut self) -> Result<String, Error> {
let mut res = String::new();
loop {
let line = &self.line_pp.process_line(&self.input)?;
if let Some(line) = line {
res.push_str(line);
}
if !self.next_line() {
break;
}
if line.is_some() {
res.push_str("\n");
}
}
Ok(res)
}
}
#[test]
fn preprocess() {
// line continuation
let mut pp = PreProcessor::new(
"void main my_\
func",
);
assert_eq!(pp.process().unwrap(), "void main my_func");
// preserve #version
let mut pp = PreProcessor::new(
"#version 450 core\n\
void main()",
);
assert_eq!(pp.process().unwrap(), "#version 450 core\nvoid main()");
// simple define
let mut pp = PreProcessor::new(
"#define FOO 42 \n\
fun=FOO",
);
assert_eq!(pp.process().unwrap(), "\nfun=42");
// ifdef with else
let mut pp = PreProcessor::new(
"#define FOO\n\
#ifdef FOO\n\
foo=42\n\
#endif\n\
some=17\n\
#ifdef BAR\n\
bar=88\n\
#else\n\
mm=49\n\
#endif\n\
done=1",
);
assert_eq!(
pp.process().unwrap(),
"\n\
foo=42\n\
\n\
some=17\n\
\n\
mm=49\n\
\n\
done=1"
);
// nested ifdef/ifndef
let mut pp = PreProcessor::new(
"#define FOO\n\
#define BOO\n\
#ifdef FOO\n\
foo=42\n\
#ifdef BOO\n\
boo=44\n\
#endif\n\
ifd=0\n\
#ifndef XYZ\n\
nxyz=8\n\
#endif\n\
#endif\n\
some=17\n\
#ifdef BAR\n\
bar=88\n\
#else\n\
mm=49\n\
#endif\n\
done=1",
);
assert_eq!(
pp.process().unwrap(),
"\n\
foo=42\n\
\n\
boo=44\n\
\n\
ifd=0\n\
\n\
nxyz=8\n\
\n\
some=17\n\
\n\
mm=49\n\
\n\
done=1"
);
// undef
let mut pp = PreProcessor::new(
"#define FOO\n\
#ifdef FOO\n\
foo=42\n\
#endif\n\
some=17\n\
#undef FOO\n\
#ifdef FOO\n\
foo=88\n\
#else\n\
nofoo=66\n\
#endif\n\
done=1",
);
assert_eq!(
pp.process().unwrap(),
"\n\
foo=42\n\
\n\
some=17\n\
\n\
nofoo=66\n\
\n\
done=1"
);
// single-line comment
let mut pp = PreProcessor::new(
"#define FOO 42//1234\n\
fun=FOO",
);
assert_eq!(pp.process().unwrap(), "\nfun=42");
// multi-line comments
let mut pp = PreProcessor::new(
"#define FOO 52/*/1234\n\
#define FOO 88\n\
end of comment*/ /* one more comment */ #define FOO 56\n\
fun=FOO",
);
assert_eq!(pp.process().unwrap(), "\nfun=56");
// unmatched endif
let mut pp = PreProcessor::new(
"#ifdef FOO\n\
foo=42\n\
#endif\n\
#endif",
);
assert_eq!(pp.process(), Err(Error::UnmatchedEndif));
// unmatched else
let mut pp = PreProcessor::new(
"#ifdef FOO\n\
foo=42\n\
#else\n\
bar=88\n\
#else\n\
bad=true\n\
#endif",
);
assert_eq!(pp.process(), Err(Error::UnmatchedElse));
}

View file

@ -37,6 +37,21 @@ pub fn parse_type(type_name: &str) -> Option<Type> {
width: 4,
},
}),
"texture2D" => Some(Type {
name: None,
inner: TypeInner::Image {
dim: crate::ImageDimension::D2,
arrayed: false,
class: crate::ImageClass::Sampled {
kind: ScalarKind::Float,
multi: false,
},
},
}),
"sampler" => Some(Type {
name: None,
inner: TypeInner::Sampler { comparison: false },
}),
word => {
fn kind_width_parse(ty: &str) -> Option<(ScalarKind, u8)> {
Some(match ty {

View file

@ -38,6 +38,7 @@ impl Program {
width: 4,
},
}),
init: None,
interpolation: None,
storage_access: StorageAccess::empty(),
});
@ -72,6 +73,7 @@ impl Program {
width: 4,
},
}),
init: None,
interpolation: None,
storage_access: StorageAccess::empty(),
});

View file

@ -43,21 +43,6 @@ pub fn map_vector_size(word: spirv::Word) -> Result<crate::VectorSize, Error> {
}
}
pub fn map_storage_class(word: spirv::Word) -> Result<crate::StorageClass, Error> {
use spirv::StorageClass as Sc;
match Sc::from_u32(word) {
Some(Sc::UniformConstant) => Ok(crate::StorageClass::Constant),
Some(Sc::Function) => Ok(crate::StorageClass::Function),
Some(Sc::Input) => Ok(crate::StorageClass::Input),
Some(Sc::Output) => Ok(crate::StorageClass::Output),
Some(Sc::Private) => Ok(crate::StorageClass::Private),
Some(Sc::StorageBuffer) => Ok(crate::StorageClass::StorageBuffer),
Some(Sc::Uniform) => Ok(crate::StorageClass::Uniform),
Some(Sc::Workgroup) => Ok(crate::StorageClass::WorkGroup),
_ => Err(Error::UnsupportedStorageClass(word)),
}
}
pub fn map_image_dim(word: spirv::Word) -> Result<crate::ImageDimension, Error> {
use spirv::Dim as D;
match D::from_u32(word) {

View file

@ -46,9 +46,11 @@ pub enum Error {
InvalidAsType(Handle<crate::Type>),
InconsistentComparisonSampling(Handle<crate::Type>),
WrongFunctionResultType(spirv::Word),
WrongFunctionParameterType(spirv::Word),
WrongFunctionArgumentType(spirv::Word),
MissingDecoration(spirv::Decoration),
BadString,
IncompleteData,
InvalidTerminator,
InvalidEdgeClassification,
UnexpectedComparisonType(Handle<crate::Type>),
}

View file

@ -152,6 +152,12 @@ impl FlowGraph {
let (node_source_index, node_target_index) =
self.flow.edge_endpoints(edge_index).unwrap();
if self.flow[node_source_index].ty == Some(ControlFlowNodeType::Header)
|| self.flow[node_source_index].ty == Some(ControlFlowNodeType::Loop)
{
continue;
}
// Back
if self.flow[node_target_index].ty == Some(ControlFlowNodeType::Loop)
&& self.flow[node_source_index].id > self.flow[node_target_index].id
@ -219,10 +225,8 @@ impl FlowGraph {
node_index: BlockNodeIndex,
stop_node_index: Option<BlockNodeIndex>,
) -> Result<crate::Block, Error> {
if let Some(stop_node_index) = stop_node_index {
if stop_node_index == node_index {
return Ok(vec![]);
}
if stop_node_index == Some(node_index) {
return Ok(vec![]);
}
let node = &self.flow[node_index];
@ -246,7 +250,7 @@ impl FlowGraph {
accept: self.naga_traverse(true_node_index, Some(merge_node_index))?,
reject: self.naga_traverse(false_node_index, Some(merge_node_index))?,
});
result.extend(self.naga_traverse(merge_node_index, None)?);
result.extend(self.naga_traverse(merge_node_index, stop_node_index)?);
} else {
result.push(crate::Statement::If {
condition,
@ -254,7 +258,7 @@ impl FlowGraph {
self.block_to_node[&true_id],
Some(merge_node_index),
)?,
reject: self.naga_traverse(merge_node_index, None)?,
reject: self.naga_traverse(merge_node_index, stop_node_index)?,
});
}
@ -305,7 +309,7 @@ impl FlowGraph {
.naga_traverse(self.block_to_node[&default], Some(merge_node_index))?,
});
result.extend(self.naga_traverse(merge_node_index, None)?);
result.extend(self.naga_traverse(merge_node_index, stop_node_index)?);
Ok(result)
}
@ -323,7 +327,7 @@ impl FlowGraph {
self.flow[continue_edge.target()].block.clone()
};
let mut body: crate::Block = node.block.clone();
let mut body = node.block.clone();
match node.terminator {
Terminator::BranchConditional {
condition,
@ -342,7 +346,10 @@ impl FlowGraph {
_ => return Err(Error::InvalidTerminator),
};
Ok(vec![crate::Statement::Loop { body, continuing }])
let mut result = vec![crate::Statement::Loop { body, continuing }];
result.extend(self.naga_traverse(merge_node_index, stop_node_index)?);
Ok(result)
}
Some(ControlFlowNodeType::Break) => {
let mut result = node.block.clone();
@ -351,25 +358,52 @@ impl FlowGraph {
condition,
true_id,
false_id,
} => result.push(crate::Statement::If {
condition,
accept: self
.naga_traverse(self.block_to_node[&true_id], stop_node_index)?,
reject: self
.naga_traverse(self.block_to_node[&false_id], stop_node_index)?,
}),
} => {
let true_node_id = self.block_to_node[&true_id];
let false_node_id = self.block_to_node[&false_id];
let true_edge =
self.flow[self.flow.find_edge(node_index, true_node_id).unwrap()];
let false_edge =
self.flow[self.flow.find_edge(node_index, false_node_id).unwrap()];
if true_edge == ControlFlowEdgeType::LoopBreak {
result.push(crate::Statement::If {
condition,
accept: vec![crate::Statement::Break],
reject: self.naga_traverse(false_node_id, stop_node_index)?,
});
} else if false_edge == ControlFlowEdgeType::LoopBreak {
result.push(crate::Statement::If {
condition,
accept: self.naga_traverse(true_node_id, stop_node_index)?,
reject: vec![crate::Statement::Break],
});
} else {
return Err(Error::InvalidEdgeClassification);
}
}
Terminator::Branch { .. } => {
result.push(crate::Statement::Break);
}
_ => return Err(Error::InvalidTerminator),
};
Ok(result)
}
Some(ControlFlowNodeType::Continue) => {
let back_block = match node.terminator {
Terminator::Branch { target_id } => {
self.naga_traverse(self.block_to_node[&target_id], None)?
}
_ => return Err(Error::InvalidTerminator),
};
let mut result = node.block.clone();
result.extend(back_block);
result.push(crate::Statement::Continue);
Ok(result)
}
Some(ControlFlowNodeType::Back) | Some(ControlFlowNodeType::Merge) => {
Ok(node.block.clone())
}
Some(ControlFlowNodeType::Back) => Ok(node.block.clone()),
Some(ControlFlowNodeType::Kill) => {
let mut result = node.block.clone();
result.push(crate::Statement::Kill);
@ -384,7 +418,7 @@ impl FlowGraph {
result.push(crate::Statement::Return { value });
Ok(result)
}
None => match node.terminator {
Some(ControlFlowNodeType::Merge) | None => match node.terminator {
Terminator::Branch { target_id } => {
let mut result = node.block.clone();
result.extend(
@ -401,7 +435,7 @@ impl FlowGraph {
pub(super) fn to_graphviz(&self) -> Result<String, std::fmt::Error> {
let mut output = String::new();
output += "digraph ControlFlowGraph {";
output += "digraph ControlFlowGraph {\n";
for node_index in self.flow.node_indices() {
let node = &self.flow[node_index];
@ -419,10 +453,16 @@ impl FlowGraph {
let target = edge.target();
let style = match edge.weight {
ControlFlowEdgeType::Forward => "",
ControlFlowEdgeType::ForwardMerge => "style=dotted",
ControlFlowEdgeType::ForwardContinue => "color=green",
ControlFlowEdgeType::Back => "style=dashed",
ControlFlowEdgeType::LoopBreak => "color=yellow",
ControlFlowEdgeType::LoopContinue => "color=green",
ControlFlowEdgeType::IfTrue => "color=blue",
ControlFlowEdgeType::IfFalse => "color=red",
ControlFlowEdgeType::ForwardMerge => "style=dotted",
_ => "",
ControlFlowEdgeType::SwitchBreak => "color=yellow",
ControlFlowEdgeType::CaseFallThrough => "style=dotted",
};
writeln!(

View file

@ -69,7 +69,7 @@ impl<I: Iterator<Item = u32>> super::Parser<I> {
}
crate::Function {
name: self.future_decor.remove(&fun_id).and_then(|dec| dec.name),
parameter_types: Vec::with_capacity(ft.parameter_type_ids.len()),
arguments: Vec::with_capacity(ft.parameter_type_ids.len()),
return_type: if self.lookup_void_type.contains(&result_type) {
None
} else {
@ -83,7 +83,7 @@ impl<I: Iterator<Item = u32>> super::Parser<I> {
};
// read parameters
for i in 0..fun.parameter_types.capacity() {
for i in 0..fun.arguments.capacity() {
match self.next_inst()? {
Instruction {
op: spirv::Op::FunctionParameter,
@ -93,7 +93,7 @@ impl<I: Iterator<Item = u32>> super::Parser<I> {
let id = self.next()?;
let handle = fun
.expressions
.append(crate::Expression::FunctionParameter(i as u32));
.append(crate::Expression::FunctionArgument(i as u32));
self.lookup_expression
.insert(id, LookupExpression { type_id, handle });
//Note: we redo the lookup in order to work around `self` borrowing
@ -104,10 +104,11 @@ impl<I: Iterator<Item = u32>> super::Parser<I> {
.lookup(fun_type)?
.parameter_type_ids[i]
{
return Err(Error::WrongFunctionParameterType(type_id));
return Err(Error::WrongFunctionArgumentType(type_id));
}
let ty = self.lookup_type.lookup(type_id)?.handle;
fun.parameter_types.push(ty);
fun.arguments
.push(crate::FunctionArgument { name: None, ty });
}
Instruction { op, .. } => return Err(Error::InvalidParameter(op)),
}
@ -175,6 +176,17 @@ impl<I: Iterator<Item = u32>> super::Parser<I> {
}
};
if let Some(ref prefix) = self.options.flow_graph_dump_prefix {
let dump = flow_graph.to_graphviz().unwrap_or_default();
let suffix = match source {
DeferredSource::EntryPoint(stage, ref name) => {
format!("flow.{:?}-{}.dot", stage, name)
}
DeferredSource::Function(handle) => format!("flow.Fun-{}.dot", handle.index()),
};
let _ = std::fs::write(prefix.join(suffix), dump);
}
for (expr_handle, dst_id) in local_function_calls {
self.deferred_function_calls.push(DeferredFunctionCall {
source: source.clone(),

View file

@ -29,7 +29,7 @@ use crate::{
};
use num_traits::cast::FromPrimitive;
use std::{convert::TryInto, num::NonZeroU32};
use std::{convert::TryInto, num::NonZeroU32, path::PathBuf};
pub const SUPPORTED_CAPABILITIES: &[spirv::Capability] = &[
spirv::Capability::Shader,
@ -304,6 +304,11 @@ pub struct Assignment {
value: Handle<crate::Expression>,
}
#[derive(Clone, Debug, Default)]
pub struct Options {
pub flow_graph_dump_prefix: Option<PathBuf>,
}
pub struct Parser<I> {
data: I,
state: ModuleState,
@ -325,10 +330,11 @@ pub struct Parser<I> {
lookup_function: FastHashMap<spirv::Word, Handle<crate::Function>>,
lookup_entry_point: FastHashMap<spirv::Word, EntryPoint>,
deferred_function_calls: Vec<DeferredFunctionCall>,
options: Options,
}
impl<I: Iterator<Item = u32>> Parser<I> {
pub fn new(data: I) -> Self {
pub fn new(data: I, options: &Options) -> Self {
Parser {
data,
state: ModuleState::Empty,
@ -349,6 +355,7 @@ impl<I: Iterator<Item = u32>> Parser<I> {
lookup_function: FastHashMap::default(),
lookup_entry_point: FastHashMap::default(),
deferred_function_calls: Vec::new(),
options: options.clone(),
}
}
@ -547,8 +554,8 @@ impl<I: Iterator<Item = u32>> Parser<I> {
let init = if inst.wc > 4 {
inst.expect(5)?;
let init_id = self.next()?;
let lexp = self.lookup_expression.lookup(init_id)?;
Some(lexp.handle)
let lconst = self.lookup_constant.lookup(init_id)?;
Some(lconst.handle)
} else {
None
};
@ -852,6 +859,38 @@ impl<I: Iterator<Item = u32>> Parser<I> {
},
);
}
// Bitwise instructions
Op::Not => {
inst.expect(4)?;
self.parse_expr_unary_op(expressions, crate::UnaryOperator::Not)?;
}
Op::BitwiseOr => {
inst.expect(5)?;
self.parse_expr_binary_op(expressions, crate::BinaryOperator::InclusiveOr)?;
}
Op::BitwiseXor => {
inst.expect(5)?;
self.parse_expr_binary_op(expressions, crate::BinaryOperator::ExclusiveOr)?;
}
Op::BitwiseAnd => {
inst.expect(5)?;
self.parse_expr_binary_op(expressions, crate::BinaryOperator::And)?;
}
Op::ShiftRightLogical => {
inst.expect(5)?;
//TODO: convert input and result to usigned
self.parse_expr_binary_op(expressions, crate::BinaryOperator::ShiftRight)?;
}
Op::ShiftRightArithmetic => {
inst.expect(5)?;
//TODO: convert input and result to signed
self.parse_expr_binary_op(expressions, crate::BinaryOperator::ShiftRight)?;
}
Op::ShiftLeftLogical => {
inst.expect(5)?;
self.parse_expr_binary_op(expressions, crate::BinaryOperator::ShiftLeft)?;
}
// Sampling
Op::SampledImage => {
inst.expect(5)?;
let _result_type_id = self.next()?;
@ -1028,6 +1067,31 @@ impl<I: Iterator<Item = u32>> Parser<I> {
},
);
}
Op::Select => {
inst.expect(6)?;
let result_type_id = self.next()?;
let result_id = self.next()?;
let condition = self.next()?;
let o1_id = self.next()?;
let o2_id = self.next()?;
let cond_lexp = self.lookup_expression.lookup(condition)?;
let o1_lexp = self.lookup_expression.lookup(o1_id)?;
let o2_lexp = self.lookup_expression.lookup(o2_id)?;
let expr = crate::Expression::Select {
condition: cond_lexp.handle,
accept: o1_lexp.handle,
reject: o2_lexp.handle,
};
self.lookup_expression.insert(
result_id,
LookupExpression {
handle: expressions.append(expr),
type_id: result_type_id,
},
);
}
Op::VectorShuffle => {
inst.expect_at_least(5)?;
let result_type_id = self.next()?;
@ -1092,11 +1156,10 @@ impl<I: Iterator<Item = u32>> Parser<I> {
let value_lexp = self.lookup_expression.lookup(value_id)?;
let ty_lookup = self.lookup_type.lookup(result_type_id)?;
let kind = match type_arena[ty_lookup.handle].inner {
crate::TypeInner::Scalar { kind, .. }
| crate::TypeInner::Vector { kind, .. } => kind,
_ => return Err(Error::InvalidAsType(ty_lookup.handle)),
};
let kind = type_arena[ty_lookup.handle]
.inner
.scalar_kind()
.ok_or(Error::InvalidAsType(ty_lookup.handle))?;
let expr = crate::Expression::As {
expr: value_lexp.handle,
@ -1220,6 +1283,10 @@ impl<I: Iterator<Item = u32>> Parser<I> {
inst.expect(base_wc + 1)?;
"length"
}
Some(spirv::GLOp::Distance) => {
inst.expect(base_wc + 2)?;
"distance"
}
Some(spirv::GLOp::Cross) => {
inst.expect(base_wc + 2)?;
"cross"
@ -1477,7 +1544,9 @@ impl<I: Iterator<Item = u32>> Parser<I> {
};
*comparison = true;
}
_ => panic!("Unexpected comparison type {:?}", ty),
_ => {
return Err(Error::UnexpectedComparisonType(handle));
}
}
}
@ -1906,12 +1975,13 @@ impl<I: Iterator<Item = u32>> Parser<I> {
inst.expect(4)?;
let id = self.next()?;
let type_id = self.next()?;
let length = self.next()?;
let length_id = self.next()?;
let length_const = self.lookup_constant.lookup(length_id)?;
let decor = self.future_decor.remove(&id);
let inner = crate::TypeInner::Array {
base: self.lookup_type.lookup(type_id)?.handle,
size: crate::ArraySize::Static(length),
size: crate::ArraySize::Constant(length_const.handle),
stride: decor.as_ref().and_then(|dec| dec.array_stride),
};
self.lookup_type.insert(
@ -2030,10 +2100,10 @@ impl<I: Iterator<Item = u32>> Parser<I> {
let format = self.next()?;
let base_handle = self.lookup_type.lookup(sample_type_id)?.handle;
let kind = match module.types[base_handle].inner {
crate::TypeInner::Scalar { kind, .. } | crate::TypeInner::Vector { kind, .. } => kind,
_ => return Err(Error::InvalidImageBaseType(base_handle)),
};
let kind = module.types[base_handle]
.inner
.scalar_kind()
.ok_or(Error::InvalidImageBaseType(base_handle))?;
let class = if format != 0 {
crate::ImageClass::Storage(map_image_format(format)?)
@ -2231,28 +2301,52 @@ impl<I: Iterator<Item = u32>> Parser<I> {
let type_id = self.next()?;
let id = self.next()?;
let storage_class = self.next()?;
if inst.wc != 4 {
let init = if inst.wc > 4 {
inst.expect(5)?;
let _init = self.next()?; //TODO
}
let init_id = self.next()?;
let lconst = self.lookup_constant.lookup(init_id)?;
Some(lconst.handle)
} else {
None
};
let lookup_type = self.lookup_type.lookup(type_id)?;
let dec = self
.future_decor
.remove(&id)
.ok_or(Error::InvalidBinding(id))?;
let class = map_storage_class(storage_class)?;
let class = {
use spirv::StorageClass as Sc;
match Sc::from_u32(storage_class) {
Some(Sc::Function) => crate::StorageClass::Function,
Some(Sc::Input) => crate::StorageClass::Input,
Some(Sc::Output) => crate::StorageClass::Output,
Some(Sc::Private) => crate::StorageClass::Private,
Some(Sc::UniformConstant) => crate::StorageClass::Handle,
Some(Sc::StorageBuffer) => crate::StorageClass::Storage,
Some(Sc::Uniform) => {
if self
.lookup_storage_buffer_types
.contains(&lookup_type.handle)
{
crate::StorageClass::Storage
} else {
crate::StorageClass::Uniform
}
}
Some(Sc::Workgroup) => crate::StorageClass::WorkGroup,
Some(Sc::PushConstant) => crate::StorageClass::PushConstant,
_ => return Err(Error::UnsupportedStorageClass(storage_class)),
}
};
let binding = match (class, &module.types[lookup_type.handle].inner) {
(crate::StorageClass::Input, &crate::TypeInner::Struct { .. })
| (crate::StorageClass::Output, &crate::TypeInner::Struct { .. }) => None,
_ => Some(dec.get_binding().ok_or(Error::InvalidBinding(id))?),
};
let is_storage = match module.types[lookup_type.handle].inner {
crate::TypeInner::Struct { .. } => match class {
crate::StorageClass::StorageBuffer => true,
_ => self
.lookup_storage_buffer_types
.contains(&lookup_type.handle),
},
crate::TypeInner::Struct { .. } => class == crate::StorageClass::Storage,
crate::TypeInner::Image {
class: crate::ImageClass::Storage(_),
..
@ -2278,6 +2372,7 @@ impl<I: Iterator<Item = u32>> Parser<I> {
class,
binding,
ty: lookup_type.handle,
init,
interpolation: dec.interpolation,
storage_access,
};
@ -2292,7 +2387,7 @@ impl<I: Iterator<Item = u32>> Parser<I> {
}
}
pub fn parse_u8_slice(data: &[u8]) -> Result<crate::Module, Error> {
pub fn parse_u8_slice(data: &[u8], options: &Options) -> Result<crate::Module, Error> {
if data.len() % 4 != 0 {
return Err(Error::IncompleteData);
}
@ -2300,7 +2395,7 @@ pub fn parse_u8_slice(data: &[u8]) -> Result<crate::Module, Error> {
let words = data
.chunks(4)
.map(|c| u32::from_le_bytes(c.try_into().unwrap()));
Parser::new(words).parse()
Parser::new(words, options).parse()
}
#[cfg(test)]
@ -2316,6 +2411,6 @@ mod test {
0x0e, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, // GLSL450.
0x01, 0x00, 0x00, 0x00,
];
let _ = super::parse_u8_slice(&bin).unwrap();
let _ = super::parse_u8_slice(&bin, &Default::default()).unwrap();
}
}

View file

@ -9,7 +9,7 @@ fn rosetta_test(file_name: &str) {
let file_path = Path::new(TEST_PATH).join(file_name);
let input = fs::read(&file_path).unwrap();
let module = super::parse_u8_slice(&input).unwrap();
let module = super::parse_u8_slice(&input, &Default::default()).unwrap();
let output = ron::ser::to_string_pretty(&module, Default::default()).unwrap();
let expected = fs::read_to_string(file_path.with_extension("expected.ron")).unwrap();

View file

@ -4,8 +4,9 @@ pub fn map_storage_class(word: &str) -> Result<crate::StorageClass, Error<'_>> {
match word {
"in" => Ok(crate::StorageClass::Input),
"out" => Ok(crate::StorageClass::Output),
"private" => Ok(crate::StorageClass::Private),
"uniform" => Ok(crate::StorageClass::Uniform),
"storage_buffer" => Ok(crate::StorageClass::StorageBuffer),
"storage" => Ok(crate::StorageClass::Storage),
_ => Err(Error::UnknownStorageClass(word)),
}
}

View file

@ -69,12 +69,7 @@ fn consume_token(mut input: &str) -> (Token<'_>, &str) {
if next == Some('=') {
(Token::LogicalOperation(cur), chars.as_str())
} else if next == Some(cur) {
input = chars.as_str();
if chars.next() == Some(cur) {
(Token::ArithmeticShiftOperation(cur), chars.as_str())
} else {
(Token::ShiftOperation(cur), input)
}
(Token::ShiftOperation(cur), chars.as_str())
} else {
(Token::Paren(cur), input)
}

View file

@ -26,7 +26,6 @@ pub enum Token<'a> {
Operation(char),
LogicalOperation(char),
ShiftOperation(char),
ArithmeticShiftOperation(char),
Arrow,
Unknown(char),
UnterminatedString,
@ -37,8 +36,6 @@ pub enum Token<'a> {
pub enum Error<'a> {
#[error("unexpected token: {0:?}")]
Unexpected(Token<'a>),
#[error("constant {0:?} doesn't match its type {1:?}")]
UnexpectedConstantType(crate::ConstantInner, Handle<crate::Type>),
#[error("unable to parse `{0}` as integer: {1}")]
BadInteger(&'a str, std::num::ParseIntError),
#[error("unable to parse `{1}` as float: {1}")]
@ -100,7 +97,7 @@ struct StatementContext<'input, 'temp, 'out> {
types: &'out mut Arena<crate::Type>,
constants: &'out mut Arena<crate::Constant>,
global_vars: &'out Arena<crate::GlobalVariable>,
parameter_types: &'out [Handle<crate::Type>],
arguments: &'out [crate::FunctionArgument],
}
impl<'a> StatementContext<'a, '_, '_> {
@ -113,7 +110,7 @@ impl<'a> StatementContext<'a, '_, '_> {
types: self.types,
constants: self.constants,
global_vars: self.global_vars,
parameter_types: self.parameter_types,
arguments: self.arguments,
}
}
@ -126,7 +123,7 @@ impl<'a> StatementContext<'a, '_, '_> {
constants: self.constants,
global_vars: self.global_vars,
local_vars: self.variables,
parameter_types: self.parameter_types,
arguments: self.arguments,
}
}
}
@ -139,7 +136,7 @@ struct ExpressionContext<'input, 'temp, 'out> {
constants: &'out mut Arena<crate::Constant>,
global_vars: &'out Arena<crate::GlobalVariable>,
local_vars: &'out Arena<crate::LocalVariable>,
parameter_types: &'out [Handle<crate::Type>],
arguments: &'out [crate::FunctionArgument],
}
impl<'a> ExpressionContext<'a, '_, '_> {
@ -152,7 +149,7 @@ impl<'a> ExpressionContext<'a, '_, '_> {
constants: self.constants,
global_vars: self.global_vars,
local_vars: self.local_vars,
parameter_types: self.parameter_types,
arguments: self.arguments,
}
}
@ -166,7 +163,7 @@ impl<'a> ExpressionContext<'a, '_, '_> {
global_vars: self.global_vars,
local_vars: self.local_vars,
functions: &functions,
parameter_types: self.parameter_types,
arguments: self.arguments,
};
match self
.typifier
@ -265,6 +262,7 @@ struct ParsedVariable<'a> {
class: Option<crate::StorageClass>,
ty: Handle<crate::Type>,
access: crate::StorageAccess,
init: Option<Handle<crate::Constant>>,
}
#[derive(Clone, Debug, Error)]
@ -375,9 +373,10 @@ impl Parser {
fn parse_const_expression<'a>(
&mut self,
lexer: &mut Lexer<'a>,
self_ty: Handle<crate::Type>,
type_arena: &mut Arena<crate::Type>,
const_arena: &mut Arena<crate::Constant>,
) -> Result<crate::ConstantInner, Error<'a>> {
) -> Result<Handle<crate::Constant>, Error<'a>> {
self.scopes.push(Scope::ConstantExpr);
let inner = match lexer.peek() {
Token::Word("true") => {
@ -394,7 +393,7 @@ impl Parser {
inner
}
_ => {
let composite_ty = self.parse_type_decl(lexer, None, type_arena)?;
let composite_ty = self.parse_type_decl(lexer, None, type_arena, const_arena)?;
lexer.expect(Token::Paren('('))?;
let mut components = Vec::new();
while !lexer.skip(Token::Paren(')')) {
@ -406,19 +405,21 @@ impl Parser {
composite_ty,
components.len(),
)?;
let inner = self.parse_const_expression(lexer, type_arena, const_arena)?;
components.push(const_arena.fetch_or_append(crate::Constant {
name: None,
specialization: None,
inner,
ty,
}));
let component =
self.parse_const_expression(lexer, ty, type_arena, const_arena)?;
components.push(component);
}
crate::ConstantInner::Composite(components)
}
};
let handle = const_arena.fetch_or_append(crate::Constant {
name: None,
specialization: None,
inner,
ty: self_ty,
});
self.scopes.pop();
Ok(inner)
Ok(handle)
}
fn parse_primary_expression<'a>(
@ -490,7 +491,7 @@ impl Parser {
expr
} else {
*lexer = backup;
let ty = self.parse_type_decl(lexer, None, ctx.types)?;
let ty = self.parse_type_decl(lexer, None, ctx.types, ctx.constants)?;
lexer.expect(Token::Paren('('))?;
let mut components = Vec::new();
while !lexer.skip(Token::Paren(')')) {
@ -790,13 +791,10 @@ impl Parser {
lexer,
|token| match token {
Token::ShiftOperation('<') => {
Some(crate::BinaryOperator::ShiftLeftLogical)
Some(crate::BinaryOperator::ShiftLeft)
}
Token::ShiftOperation('>') => {
Some(crate::BinaryOperator::ShiftRightLogical)
}
Token::ArithmeticShiftOperation('>') => {
Some(crate::BinaryOperator::ShiftRightArithmetic)
Some(crate::BinaryOperator::ShiftRight)
}
_ => None,
},
@ -910,10 +908,11 @@ impl Parser {
&mut self,
lexer: &mut Lexer<'a>,
type_arena: &mut Arena<crate::Type>,
const_arena: &mut Arena<crate::Constant>,
) -> Result<(&'a str, Handle<crate::Type>), Error<'a>> {
let name = lexer.next_ident()?;
lexer.expect(Token::Separator(':'))?;
let ty = self.parse_type_decl(lexer, None, type_arena)?;
let ty = self.parse_type_decl(lexer, None, type_arena, const_arena)?;
Ok((name, ty))
}
@ -932,16 +931,27 @@ impl Parser {
}
let name = lexer.next_ident()?;
lexer.expect(Token::Separator(':'))?;
let ty = self.parse_type_decl(lexer, None, type_arena)?;
let ty = self.parse_type_decl(lexer, None, type_arena, const_arena)?;
let access = match class {
Some(crate::StorageClass::StorageBuffer) => crate::StorageAccess::all(),
Some(crate::StorageClass::Constant) => crate::StorageAccess::LOAD,
Some(crate::StorageClass::Storage) => crate::StorageAccess::all(),
Some(crate::StorageClass::Handle) => {
match type_arena[ty].inner {
//TODO: RW textures
crate::TypeInner::Image {
class: crate::ImageClass::Storage(_),
..
} => crate::StorageAccess::LOAD,
_ => crate::StorageAccess::empty(),
}
}
_ => crate::StorageAccess::empty(),
};
if lexer.skip(Token::Operation('=')) {
let _inner = self.parse_const_expression(lexer, type_arena, const_arena)?;
//TODO
}
let init = if lexer.skip(Token::Operation('=')) {
let handle = self.parse_const_expression(lexer, ty, type_arena, const_arena)?;
Some(handle)
} else {
None
};
lexer.expect(Token::Separator(';'))?;
self.scopes.pop();
Ok(ParsedVariable {
@ -949,6 +959,7 @@ impl Parser {
class,
ty,
access,
init,
})
}
@ -956,6 +967,7 @@ impl Parser {
&mut self,
lexer: &mut Lexer<'a>,
type_arena: &mut Arena<crate::Type>,
const_arena: &mut Arena<crate::Constant>,
) -> Result<Vec<crate::StructMember>, Error<'a>> {
let mut members = Vec::new();
lexer.expect(Token::Paren('{'))?;
@ -992,7 +1004,7 @@ impl Parser {
return Err(Error::MissingMemberOffset(name));
}
lexer.expect(Token::Separator(':'))?;
let ty = self.parse_type_decl(lexer, None, type_arena)?;
let ty = self.parse_type_decl(lexer, None, type_arena, const_arena)?;
lexer.expect(Token::Separator(';'))?;
members.push(crate::StructMember {
name: Some(name.to_owned()),
@ -1007,6 +1019,7 @@ impl Parser {
lexer: &mut Lexer<'a>,
self_name: Option<&'a str>,
type_arena: &mut Arena<crate::Type>,
const_arena: &mut Arena<crate::Constant>,
) -> Result<Handle<crate::Type>, Error<'a>> {
self.scopes.push(Scope::TypeDecl);
let decoration_lexer = if lexer.skip(Token::DoubleParen('[')) {
@ -1128,18 +1141,30 @@ impl Parser {
lexer.expect(Token::Paren('<'))?;
let class = conv::map_storage_class(lexer.next_ident()?)?;
lexer.expect(Token::Separator(','))?;
let base = self.parse_type_decl(lexer, None, type_arena)?;
let base = self.parse_type_decl(lexer, None, type_arena, const_arena)?;
lexer.expect(Token::Paren('>'))?;
crate::TypeInner::Pointer { base, class }
}
Token::Word("array") => {
lexer.expect(Token::Paren('<'))?;
let base = self.parse_type_decl(lexer, None, type_arena)?;
let base = self.parse_type_decl(lexer, None, type_arena, const_arena)?;
let size = match lexer.next() {
Token::Separator(',') => {
let value = lexer.next_uint_literal()?;
lexer.expect(Token::Paren('>'))?;
crate::ArraySize::Static(value)
let const_handle = const_arena.fetch_or_append(crate::Constant {
name: None,
specialization: None,
inner: crate::ConstantInner::Uint(value as u64),
ty: type_arena.fetch_or_append(crate::Type {
name: None,
inner: crate::TypeInner::Scalar {
kind: crate::ScalarKind::Uint,
width: 4,
},
}),
});
crate::ArraySize::Constant(const_handle)
}
Token::Paren('>') => crate::ArraySize::Dynamic,
other => return Err(Error::Unexpected(other)),
@ -1167,7 +1192,7 @@ impl Parser {
crate::TypeInner::Array { base, size, stride }
}
Token::Word("struct") => {
let members = self.parse_struct_body(lexer, type_arena)?;
let members = self.parse_struct_body(lexer, type_arena, const_arena)?;
crate::TypeInner::Struct { members }
}
Token::Word("sampler") => crate::TypeInner::Sampler { comparison: false },
@ -1368,15 +1393,20 @@ impl Parser {
"var" => {
enum Init {
Empty,
Uniform(Handle<crate::Expression>),
Constant(Handle<crate::Constant>),
Variable(Handle<crate::Expression>),
}
let (name, ty) = self.parse_variable_ident_decl(lexer, context.types)?;
let (name, ty) = self.parse_variable_ident_decl(
lexer,
context.types,
context.constants,
)?;
let init = if lexer.skip(Token::Operation('=')) {
let value =
self.parse_general_expression(lexer, context.as_expression())?;
if let crate::Expression::Constant(_) = context.expressions[value] {
Init::Uniform(value)
if let crate::Expression::Constant(handle) = context.expressions[value]
{
Init::Constant(handle)
} else {
Init::Variable(value)
}
@ -1388,7 +1418,7 @@ impl Parser {
name: Some(name.to_owned()),
ty,
init: match init {
Init::Uniform(value) => Some(value),
Init::Constant(value) => Some(value),
_ => None,
},
});
@ -1515,31 +1545,34 @@ impl Parser {
lookup_ident.insert(name, expr_handle);
}
// read parameter list
let mut parameter_types = Vec::new();
let mut arguments = Vec::new();
lexer.expect(Token::Paren('('))?;
while !lexer.skip(Token::Paren(')')) {
if !parameter_types.is_empty() {
if !arguments.is_empty() {
lexer.expect(Token::Separator(','))?;
}
let (param_name, param_type) =
self.parse_variable_ident_decl(lexer, &mut module.types)?;
let param_index = parameter_types.len() as u32;
self.parse_variable_ident_decl(lexer, &mut module.types, &mut module.constants)?;
let param_index = arguments.len() as u32;
let expression_token =
expressions.append(crate::Expression::FunctionParameter(param_index));
expressions.append(crate::Expression::FunctionArgument(param_index));
lookup_ident.insert(param_name, expression_token);
parameter_types.push(param_type);
arguments.push(crate::FunctionArgument {
name: Some(param_name.to_string()),
ty: param_type,
});
}
// read return type
lexer.expect(Token::Arrow)?;
let return_type = if lexer.skip(Token::Word("void")) {
None
} else {
Some(self.parse_type_decl(lexer, None, &mut module.types)?)
Some(self.parse_type_decl(lexer, None, &mut module.types, &mut module.constants)?)
};
let mut fun = crate::Function {
name: Some(fun_name.to_string()),
parameter_types,
arguments,
return_type,
global_usage: Vec::new(),
local_variables: Arena::new(),
@ -1559,7 +1592,7 @@ impl Parser {
types: &mut module.types,
constants: &mut module.constants,
global_vars: &module.global_variables,
parameter_types: &fun.parameter_types,
arguments: &fun.arguments,
},
)?;
// done
@ -1680,25 +1713,29 @@ impl Parser {
Token::Word("type") => {
let name = lexer.next_ident()?;
lexer.expect(Token::Operation('='))?;
let ty = self.parse_type_decl(lexer, Some(name), &mut module.types)?;
let ty = self.parse_type_decl(
lexer,
Some(name),
&mut module.types,
&mut module.constants,
)?;
self.lookup_type.insert(name.to_owned(), ty);
lexer.expect(Token::Separator(';'))?;
}
Token::Word("const") => {
let (name, ty) = self.parse_variable_ident_decl(lexer, &mut module.types)?;
let (name, ty) = self.parse_variable_ident_decl(
lexer,
&mut module.types,
&mut module.constants,
)?;
lexer.expect(Token::Operation('='))?;
let inner =
self.parse_const_expression(lexer, &mut module.types, &mut module.constants)?;
lexer.expect(Token::Separator(';'))?;
if !crate::proc::check_constant_type(&inner, &module.types[ty].inner) {
return Err(Error::UnexpectedConstantType(inner, ty));
}
let const_handle = module.constants.append(crate::Constant {
name: Some(name.to_owned()),
specialization: None,
inner,
let const_handle = self.parse_const_expression(
lexer,
ty,
});
&mut module.types,
&mut module.constants,
)?;
lexer.expect(Token::Separator(';'))?;
lookup_global_expression.insert(name, crate::Expression::Constant(const_handle));
}
Token::Word("var") => {
@ -1712,7 +1749,7 @@ impl Parser {
crate::BuiltIn::Position => crate::StorageClass::Output,
_ => unimplemented!(),
},
_ => crate::StorageClass::Private,
_ => crate::StorageClass::Handle,
},
};
let var_handle = module.global_variables.append(crate::GlobalVariable {
@ -1720,6 +1757,7 @@ impl Parser {
class,
binding: binding.take(),
ty: pvar.ty,
init: pvar.init,
interpolation,
storage_access: pvar.access,
});
@ -1787,7 +1825,13 @@ impl Parser {
}
Ok(true) => {}
Ok(false) => {
assert_eq!(self.scopes, Vec::new());
if !self.scopes.is_empty() {
return Err(ParseError {
error: Error::Other,
scopes: std::mem::replace(&mut self.scopes, Vec::new()),
pos: (0, 0),
});
};
return Ok(module);
}
}
@ -1802,5 +1846,5 @@ pub fn parse_str(source: &str) -> Result<crate::Module, ParseError> {
#[test]
fn parse_types() {
assert!(parse_str("const a : i32 = 2;").is_ok());
assert!(parse_str("const a : i32 = 2.0;").is_err());
assert!(parse_str("const a : x32 = 2;").is_err());
}

View file

@ -4,7 +4,11 @@
//!
//! To improve performance and reduce memory usage, most structures are stored
//! in an [`Arena`], and can be retrieved using the corresponding [`Handle`].
#![allow(clippy::new_without_default, clippy::unneeded_field_pattern)]
#![allow(
clippy::new_without_default,
clippy::unneeded_field_pattern,
clippy::match_like_matches_macro
)]
#![deny(clippy::panic)]
mod arena;
@ -57,7 +61,7 @@ pub struct Header {
/// For more, see:
/// - https://www.khronos.org/opengl/wiki/Early_Fragment_Test#Explicit_specification
/// - https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-attributes-earlydepthstencil
#[derive(Clone, Copy, Debug, PartialEq)]
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub struct EarlyDepthTest {
@ -73,7 +77,7 @@ pub struct EarlyDepthTest {
/// For more, see:
/// - https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_conservative_depth.txt
/// - https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-semantics#system-value-semantics
#[derive(Clone, Copy, Debug, PartialEq)]
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum ConservativeDepth {
@ -88,7 +92,7 @@ pub enum ConservativeDepth {
}
/// Stage of the programmable pipeline.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[allow(missing_docs)] // The names are self evident
@ -99,23 +103,33 @@ pub enum ShaderStage {
}
/// Class of storage for variables.
#[derive(Clone, Copy, Debug, PartialEq)]
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[allow(missing_docs)] // The names are self evident
pub enum StorageClass {
Constant,
/// Function locals.
Function,
/// Pipeline input, per invocation.
Input,
/// Pipeline output, per invocation, mutable.
Output,
/// Private data, per invocation, mutable.
Private,
StorageBuffer,
Uniform,
/// Workgroup shared data, mutable.
WorkGroup,
/// Uniform buffer data.
Uniform,
/// Storage buffer data, potentially mutable.
Storage,
/// Opaque handles, such as samplers and images.
Handle,
/// Push constants.
PushConstant,
}
/// Built-in inputs and outputs.
#[derive(Clone, Copy, Debug, PartialEq)]
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum BuiltIn {
@ -144,7 +158,7 @@ pub type Bytes = u8;
/// Number of components in a vector.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Hash, Eq)]
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum VectorSize {
@ -158,7 +172,7 @@ pub enum VectorSize {
/// Primitive type for a scalar.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Hash, Eq)]
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum ScalarKind {
@ -174,18 +188,18 @@ pub enum ScalarKind {
/// Size of an array.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum ArraySize {
/// The array size is known at compilation.
Static(u32),
/// The array size is constant.
Constant(Handle<Constant>),
/// The array size can change at runtime.
Dynamic,
}
/// Describes where a struct member is placed.
#[derive(Clone, Debug, PartialEq)]
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum MemberOrigin {
@ -198,7 +212,7 @@ pub enum MemberOrigin {
}
/// The interpolation qualifier of a binding or struct field.
#[derive(Clone, Copy, Debug, PartialEq)]
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum Interpolation {
@ -233,7 +247,7 @@ pub struct StructMember {
}
/// The number of dimensions an image has.
#[derive(Clone, Copy, Debug, PartialEq)]
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum ImageDimension {
@ -260,7 +274,7 @@ bitflags::bitflags! {
}
// Storage image format.
#[derive(Clone, Copy, Debug, PartialEq)]
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum StorageFormat {
@ -310,7 +324,7 @@ pub enum StorageFormat {
}
/// Sub-class of the image type.
#[derive(Clone, Copy, Debug, PartialEq)]
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum ImageClass {
@ -392,8 +406,7 @@ pub struct Constant {
}
/// Additional information, dependendent on the kind of constant.
// Clone is used only for error reporting and is not intended for end users
#[derive(Clone, Debug, PartialEq)]
#[derive(Debug, PartialEq)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum ConstantInner {
@ -442,6 +455,8 @@ pub struct GlobalVariable {
pub binding: Option<Binding>,
/// The type of this variable.
pub ty: Handle<Type>,
/// Initial value for this variable.
pub init: Option<Handle<Constant>>,
/// The interpolation qualifier, if any.
/// If the this `GlobalVariable` is a vertex output
/// or fragment input, `None` corresponds to the
@ -461,11 +476,11 @@ pub struct LocalVariable {
/// The type of this variable.
pub ty: Handle<Type>,
/// Initial value for this variable.
pub init: Option<Handle<Expression>>,
pub init: Option<Handle<Constant>>,
}
/// Operation that can be applied on a single value.
#[derive(Clone, Copy, Debug, PartialEq)]
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum UnaryOperator {
@ -474,7 +489,7 @@ pub enum UnaryOperator {
}
/// Operation that can be applied on two values.
#[derive(Clone, Copy, Debug, PartialEq)]
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum BinaryOperator {
@ -494,13 +509,13 @@ pub enum BinaryOperator {
InclusiveOr,
LogicalAnd,
LogicalOr,
ShiftLeftLogical,
ShiftRightLogical,
ShiftRightArithmetic,
ShiftLeft,
/// Right shift carries the sign of signed integers only.
ShiftRight,
}
/// Built-in shader function.
#[derive(Clone, Copy, Debug, PartialEq)]
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum IntrinsicFunction {
@ -513,7 +528,7 @@ pub enum IntrinsicFunction {
}
/// Axis on which to compute a derivative.
#[derive(Clone, Copy, Debug, PartialEq)]
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum DerivativeAxis {
@ -569,7 +584,7 @@ pub enum Expression {
components: Vec<Handle<Expression>>,
},
/// Reference a function parameter, by its index.
FunctionParameter(u32),
FunctionArgument(u32),
/// Reference a global variable.
GlobalVariable(Handle<GlobalVariable>),
/// Reference a local variable.
@ -604,6 +619,13 @@ pub enum Expression {
left: Handle<Expression>,
right: Handle<Expression>,
},
/// Select between two values based on a condition.
Select {
/// Boolean expression
condition: Handle<Expression>,
accept: Handle<Expression>,
reject: Handle<Expression>,
},
/// Call an intrinsic function.
Intrinsic {
fun: IntrinsicFunction,
@ -687,6 +709,17 @@ pub enum Statement {
},
}
/// A function argument.
#[derive(Debug)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub struct FunctionArgument {
/// Name of the argument, if any.
pub name: Option<String>,
/// Type of the argument.
pub ty: Handle<Type>,
}
/// A function defined in the module.
#[derive(Debug)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
@ -694,9 +727,8 @@ pub enum Statement {
pub struct Function {
/// Name of the function, if any.
pub name: Option<String>,
//pub control: spirv::FunctionControl,
/// The types of the parameters of this function.
pub parameter_types: Vec<Handle<Type>>,
/// Information about function argument.
pub arguments: Vec<FunctionArgument>,
/// The return type of this function, if any.
pub return_type: Option<Handle<Type>>,
/// Vector of global variable usages.

View file

@ -2,6 +2,7 @@ use crate::arena::{Arena, Handle};
pub struct Interface<'a, T> {
pub expressions: &'a Arena<crate::Expression>,
pub local_variables: &'a Arena<crate::LocalVariable>,
pub visitor: T,
}
@ -36,7 +37,7 @@ where
self.traverse_expr(comp);
}
}
E::FunctionParameter(_) | E::GlobalVariable(_) | E::LocalVariable(_) => {}
E::FunctionArgument(_) | E::GlobalVariable(_) | E::LocalVariable(_) => {}
E::Load { pointer } => {
self.traverse_expr(pointer);
}
@ -78,6 +79,15 @@ where
self.traverse_expr(left);
self.traverse_expr(right);
}
E::Select {
condition,
accept,
reject,
} => {
self.traverse_expr(condition);
self.traverse_expr(accept);
self.traverse_expr(reject);
}
E::Intrinsic { argument, .. } => {
self.traverse_expr(argument);
}
@ -201,6 +211,7 @@ impl crate::Function {
let mut io = Interface {
expressions: &self.expressions,
local_variables: &self.local_variables,
visitor: GlobalUseVisitor(&mut self.global_usage),
};
io.traverse(&self.body);
@ -218,9 +229,10 @@ mod tests {
fn global_use_scan() {
let test_global = GlobalVariable {
name: None,
class: StorageClass::Constant,
class: StorageClass::Uniform,
binding: None,
ty: Handle::new(std::num::NonZeroU32::new(1).unwrap()),
init: None,
interpolation: None,
storage_access: StorageAccess::empty(),
};
@ -256,7 +268,7 @@ mod tests {
let mut function = crate::Function {
name: None,
parameter_types: Vec::new(),
arguments: Vec::new(),
return_type: None,
local_variables: Arena::new(),
expressions,

View file

@ -1,10 +1,12 @@
//! Module processing functionality.
mod interface;
mod namer;
mod typifier;
mod validator;
pub use interface::{Interface, Visitor};
pub use namer::{EntryPointIndex, NameKey, Namer};
pub use typifier::{check_constant_type, ResolveContext, ResolveError, Typifier};
pub use validator::{ValidationError, Validator};
@ -47,3 +49,15 @@ impl From<super::StorageFormat> for super::ScalarKind {
}
}
}
impl crate::TypeInner {
pub fn scalar_kind(&self) -> Option<super::ScalarKind> {
match *self {
super::TypeInner::Scalar { kind, .. } | super::TypeInner::Vector { kind, .. } => {
Some(kind)
}
super::TypeInner::Matrix { .. } => Some(super::ScalarKind::Float),
_ => None,
}
}
}

113
third_party/rust/naga/src/proc/namer.rs vendored Normal file
View file

@ -0,0 +1,113 @@
use crate::{arena::Handle, FastHashMap};
use std::collections::hash_map::Entry;
pub type EntryPointIndex = u16;
#[derive(Debug, Eq, Hash, PartialEq)]
pub enum NameKey {
GlobalVariable(Handle<crate::GlobalVariable>),
Type(Handle<crate::Type>),
StructMember(Handle<crate::Type>, u32),
Function(Handle<crate::Function>),
FunctionArgument(Handle<crate::Function>, u32),
FunctionLocal(Handle<crate::Function>, Handle<crate::LocalVariable>),
EntryPoint(EntryPointIndex),
EntryPointLocal(EntryPointIndex, Handle<crate::LocalVariable>),
}
/// This processor assigns names to all the things in a module
/// that may need identifiers in a textual backend.
pub struct Namer {
unique: FastHashMap<String, u32>,
}
impl Namer {
fn sanitize(string: &str) -> String {
let mut base = string
.chars()
.skip_while(|c| c.is_numeric())
.filter(|&c| c.is_ascii_alphanumeric() || c == '_')
.collect::<String>();
// close the name by '_' if the re is a number, so that
// we can have our own number!
match base.chars().next_back() {
Some(c) if !c.is_numeric() => {}
_ => base.push('_'),
};
base
}
fn call(&mut self, label_raw: &str) -> String {
let base = Self::sanitize(label_raw);
match self.unique.entry(base) {
Entry::Occupied(mut e) => {
*e.get_mut() += 1;
format!("{}{}", e.key(), e.get())
}
Entry::Vacant(e) => {
let name = e.key().to_string();
e.insert(0);
name
}
}
}
fn call_or(&mut self, label: &Option<String>, fallback: &str) -> String {
self.call(match *label {
Some(ref name) => name,
None => fallback,
})
}
pub fn process(
module: &crate::Module,
reserved: &[&str],
output: &mut FastHashMap<NameKey, String>,
) {
let mut this = Namer {
unique: reserved
.iter()
.map(|string| (string.to_string(), 0))
.collect(),
};
for (handle, var) in module.global_variables.iter() {
let name = this.call_or(&var.name, "global");
output.insert(NameKey::GlobalVariable(handle), name);
}
for (ty_handle, ty) in module.types.iter() {
let ty_name = this.call_or(&ty.name, "type");
output.insert(NameKey::Type(ty_handle), ty_name);
if let crate::TypeInner::Struct { ref members } = ty.inner {
for (index, member) in members.iter().enumerate() {
let name = this.call_or(&member.name, "member");
output.insert(NameKey::StructMember(ty_handle, index as u32), name);
}
}
}
for (fun_handle, fun) in module.functions.iter() {
let fun_name = this.call_or(&fun.name, "function");
output.insert(NameKey::Function(fun_handle), fun_name);
for (index, arg) in fun.arguments.iter().enumerate() {
let name = this.call_or(&arg.name, "param");
output.insert(NameKey::FunctionArgument(fun_handle, index as u32), name);
}
for (handle, var) in fun.local_variables.iter() {
let name = this.call_or(&var.name, "local");
output.insert(NameKey::FunctionLocal(fun_handle, handle), name);
}
}
for (ep_index, (&(_, ref base_name), ep)) in module.entry_points.iter().enumerate() {
let ep_name = this.call(base_name);
output.insert(NameKey::EntryPoint(ep_index as _), ep_name);
for (handle, var) in ep.function.local_variables.iter() {
let name = this.call_or(&var.name, "local");
output.insert(NameKey::EntryPointLocal(ep_index as _, handle), name);
}
}
}
}

View file

@ -29,6 +29,7 @@ impl Clone for Resolution {
columns,
width,
},
#[allow(clippy::panic)]
_ => panic!("Unepxected clone type: {:?}", v),
}),
}
@ -50,6 +51,14 @@ pub enum ResolveError {
FunctionReturnsVoid,
#[error("Type is not found in the given immutable arena")]
TypeNotFound,
#[error("Incompatible operand: {op} {operand}")]
IncompatibleOperand { op: String, operand: String },
#[error("Incompatible operands: {left} {op} {right}")]
IncompatibleOperands {
op: String,
left: String,
right: String,
},
}
pub struct ResolveContext<'a> {
@ -57,7 +66,7 @@ pub struct ResolveContext<'a> {
pub global_vars: &'a Arena<crate::GlobalVariable>,
pub local_vars: &'a Arena<crate::LocalVariable>,
pub functions: &'a Arena<crate::Function>,
pub parameter_types: &'a [Handle<crate::Type>],
pub arguments: &'a [crate::FunctionArgument],
}
impl Typifier {
@ -82,6 +91,16 @@ impl Typifier {
}
}
pub fn get_handle(
&self,
expr_handle: Handle<crate::Expression>,
) -> Option<Handle<crate::Type>> {
match self.resolutions[expr_handle.index()] {
Resolution::Handle(ty_handle) => Some(ty_handle),
Resolution::Value(_) => None,
}
}
fn resolve_impl(
&self,
expr: &crate::Expression,
@ -105,7 +124,12 @@ impl Typifier {
kind: crate::ScalarKind::Float,
width,
}),
ref other => panic!("Can't access into {:?}", other),
ref other => {
return Err(ResolveError::IncompatibleOperand {
op: "access".to_string(),
operand: format!("{:?}", other),
})
}
},
crate::Expression::AccessIndex { base, index } => match *self.get(base, types) {
crate::TypeInner::Vector { size, kind, width } => {
@ -135,12 +159,17 @@ impl Typifier {
.ok_or(ResolveError::InvalidAccessIndex)?;
Resolution::Handle(member.ty)
}
ref other => panic!("Can't access into {:?}", other),
ref other => {
return Err(ResolveError::IncompatibleOperand {
op: "access index".to_string(),
operand: format!("{:?}", other),
})
}
},
crate::Expression::Constant(h) => Resolution::Handle(ctx.constants[h].ty),
crate::Expression::Compose { ty, .. } => Resolution::Handle(ty),
crate::Expression::FunctionParameter(index) => {
Resolution::Handle(ctx.parameter_types[index as usize])
crate::Expression::FunctionArgument(index) => {
Resolution::Handle(ctx.arguments[index as usize].ty)
}
crate::Expression::GlobalVariable(h) => Resolution::Handle(ctx.global_vars[h].ty),
crate::Expression::LocalVariable(h) => Resolution::Handle(ctx.local_vars[h].ty),
@ -192,7 +221,13 @@ impl Typifier {
kind: crate::ScalarKind::Float,
width,
}),
_ => panic!("Incompatible arguments {:?} x {:?}", ty_left, ty_right),
_ => {
return Err(ResolveError::IncompatibleOperands {
op: "x".to_string(),
left: format!("{:?}", ty_left),
right: format!("{:?}", ty_right),
})
}
}
}
}
@ -207,12 +242,10 @@ impl Typifier {
crate::BinaryOperator::And
| crate::BinaryOperator::ExclusiveOr
| crate::BinaryOperator::InclusiveOr
| crate::BinaryOperator::ShiftLeftLogical
| crate::BinaryOperator::ShiftRightLogical
| crate::BinaryOperator::ShiftRightArithmetic => {
self.resolutions[left.index()].clone()
}
| crate::BinaryOperator::ShiftLeft
| crate::BinaryOperator::ShiftRight => self.resolutions[left.index()].clone(),
},
crate::Expression::Select { accept, .. } => self.resolutions[accept.index()].clone(),
crate::Expression::Intrinsic { .. } => unimplemented!(),
crate::Expression::Transpose(expr) => match *self.get(expr, types) {
crate::TypeInner::Matrix {
@ -224,7 +257,12 @@ impl Typifier {
rows: columns,
width,
}),
ref other => panic!("incompatible transpose of {:?}", other),
ref other => {
return Err(ResolveError::IncompatibleOperand {
op: "transpose".to_string(),
operand: format!("{:?}", other),
})
}
},
crate::Expression::DotProduct(left_expr, _) => match *self.get(left_expr, types) {
crate::TypeInner::Vector {
@ -232,7 +270,12 @@ impl Typifier {
size: _,
width,
} => Resolution::Value(crate::TypeInner::Scalar { kind, width }),
ref other => panic!("incompatible dot of {:?}", other),
ref other => {
return Err(ResolveError::IncompatibleOperand {
op: "dot product".to_string(),
operand: format!("{:?}", other),
})
}
},
crate::Expression::CrossProduct(_, _) => unimplemented!(),
crate::Expression::As {
@ -248,7 +291,12 @@ impl Typifier {
size,
width,
} => Resolution::Value(crate::TypeInner::Vector { kind, size, width }),
ref other => panic!("incompatible as of {:?}", other),
ref other => {
return Err(ResolveError::IncompatibleOperand {
op: "as".to_string(),
operand: format!("{:?}", other),
})
}
},
crate::Expression::Derivative { .. } => unimplemented!(),
crate::Expression::Call {
@ -260,13 +308,23 @@ impl Typifier {
| crate::TypeInner::Scalar { kind, width } => {
Resolution::Value(crate::TypeInner::Scalar { kind, width })
}
ref other => panic!("Unexpected argument {:?} on {}", other, name),
ref other => {
return Err(ResolveError::IncompatibleOperand {
op: name.clone(),
operand: format!("{:?}", other),
})
}
},
"dot" => match *self.get(arguments[0], types) {
crate::TypeInner::Vector { kind, width, .. } => {
Resolution::Value(crate::TypeInner::Scalar { kind, width })
}
ref other => panic!("Unexpected argument {:?} on {}", other, name),
ref other => {
return Err(ResolveError::IncompatibleOperand {
op: name.clone(),
operand: format!("{:?}", other),
})
}
},
//Note: `cross` is here too, we still need to figure out what to do with it
"abs" | "atan2" | "cos" | "sin" | "floor" | "inverse" | "normalize" | "min"

View file

@ -21,20 +21,37 @@ pub enum GlobalVariableError {
InvalidType,
#[error("Interpolation is not valid")]
InvalidInterpolation,
#[error("Storage access flags are invalid")]
InvalidStorageAccess,
#[error("Storage access {seen:?} exceed the allowed {allowed:?}")]
InvalidStorageAccess {
allowed: crate::StorageAccess,
seen: crate::StorageAccess,
},
#[error("Binding decoration is missing or not applicable")]
InvalidBinding,
#[error("Binding is out of range")]
OutOfRangeBinding,
}
#[derive(Clone, Debug, thiserror::Error)]
pub enum LocalVariableError {
#[error("Initializer is not a constant expression")]
InitializerConst,
#[error("Initializer doesn't match the variable type")]
InitializerType,
}
#[derive(Clone, Debug, thiserror::Error)]
pub enum FunctionError {
#[error(transparent)]
Resolve(#[from] ResolveError),
#[error("There are instructions after `return`/`break`/`continue`")]
InvalidControlFlowExitTail,
#[error("Local variable {handle:?} '{name}' is invalid: {error:?}")]
LocalVariable {
handle: Handle<crate::LocalVariable>,
name: String,
error: LocalVariableError,
},
}
#[derive(Clone, Debug, thiserror::Error)]
@ -63,8 +80,14 @@ pub enum ValidationError {
InvalidTypeWidth(crate::ScalarKind, crate::Bytes),
#[error("The type handle {0:?} can not be resolved")]
UnresolvedType(Handle<crate::Type>),
#[error("Global variable {0:?} is invalid: {1:?}")]
GlobalVariable(Handle<crate::GlobalVariable>, GlobalVariableError),
#[error("The constant {0:?} can not be used for an array size")]
InvalidArraySizeConstant(Handle<crate::Constant>),
#[error("Global variable {handle:?} '{name}' is invalid: {error:?}")]
GlobalVariable {
handle: Handle<crate::GlobalVariable>,
name: String,
error: GlobalVariableError,
},
#[error("Function {0:?} is invalid: {1:?}")]
Function(Handle<crate::Function>, FunctionError),
#[error("Entry point {name} at {stage:?} is invalid: {error:?}")]
@ -73,6 +96,43 @@ pub enum ValidationError {
name: String,
error: EntryPointError,
},
#[error("Module is corrupted")]
Corrupted,
}
impl crate::GlobalVariable {
fn forbid_interpolation(&self) -> Result<(), GlobalVariableError> {
match self.interpolation {
Some(_) => Err(GlobalVariableError::InvalidInterpolation),
None => Ok(()),
}
}
fn check_resource(&self) -> Result<(), GlobalVariableError> {
match self.binding {
Some(crate::Binding::BuiltIn(_)) => {} // validated per entry point
Some(crate::Binding::Resource { group, binding }) => {
if group > MAX_BIND_GROUPS || binding > MAX_BIND_INDICES {
return Err(GlobalVariableError::OutOfRangeBinding);
}
}
Some(crate::Binding::Location(_)) | None => {
return Err(GlobalVariableError::InvalidBinding)
}
}
self.forbid_interpolation()
}
}
fn storage_usage(access: crate::StorageAccess) -> crate::GlobalUse {
let mut storage_usage = crate::GlobalUse::empty();
if access.contains(crate::StorageAccess::LOAD) {
storage_usage |= crate::GlobalUse::LOAD;
}
if access.contains(crate::StorageAccess::STORE) {
storage_usage |= crate::GlobalUse::STORE;
}
storage_usage
}
impl Validator {
@ -89,15 +149,13 @@ impl Validator {
types: &Arena<crate::Type>,
) -> Result<(), GlobalVariableError> {
log::debug!("var {:?}", var);
let is_storage = match var.class {
let allowed_storage_access = match var.class {
crate::StorageClass::Function => return Err(GlobalVariableError::InvalidUsage),
crate::StorageClass::Input | crate::StorageClass::Output => {
match var.binding {
Some(crate::Binding::BuiltIn(_)) => {
// validated per entry point
if var.interpolation.is_some() {
return Err(GlobalVariableError::InvalidInterpolation);
}
var.forbid_interpolation()?
}
Some(crate::Binding::Location(loc)) => {
if loc > MAX_LOCATIONS {
@ -117,61 +175,73 @@ impl Validator {
match types[var.ty].inner {
//TODO: check the member types
crate::TypeInner::Struct { members: _ } => {
if var.interpolation.is_some() {
return Err(GlobalVariableError::InvalidInterpolation);
}
var.forbid_interpolation()?
}
_ => return Err(GlobalVariableError::InvalidType),
}
}
}
false
crate::StorageAccess::empty()
}
crate::StorageClass::Constant
| crate::StorageClass::StorageBuffer
| crate::StorageClass::Uniform => {
match var.binding {
Some(crate::Binding::BuiltIn(_)) => {} // validated per entry point
Some(crate::Binding::Resource { group, binding }) => {
if group > MAX_BIND_GROUPS || binding > MAX_BIND_INDICES {
return Err(GlobalVariableError::OutOfRangeBinding);
}
}
Some(crate::Binding::Location(_)) | None => {
return Err(GlobalVariableError::InvalidBinding)
}
}
if var.interpolation.is_some() {
return Err(GlobalVariableError::InvalidInterpolation);
}
//TODO: prevent `Uniform` storage class with `STORE` access
crate::StorageClass::Storage => {
var.check_resource()?;
crate::StorageAccess::all()
}
crate::StorageClass::Uniform => {
var.check_resource()?;
crate::StorageAccess::empty()
}
crate::StorageClass::Handle => {
var.check_resource()?;
match types[var.ty].inner {
crate::TypeInner::Struct { .. }
| crate::TypeInner::Image {
crate::TypeInner::Image {
class: crate::ImageClass::Storage(_),
..
} => true,
_ => false,
} => crate::StorageAccess::all(),
_ => crate::StorageAccess::empty(),
}
}
crate::StorageClass::Private | crate::StorageClass::WorkGroup => {
if var.binding.is_some() {
return Err(GlobalVariableError::InvalidBinding);
}
if var.interpolation.is_some() {
return Err(GlobalVariableError::InvalidInterpolation);
}
false
var.forbid_interpolation()?;
crate::StorageAccess::empty()
}
crate::StorageClass::PushConstant => {
//TODO
return Err(GlobalVariableError::InvalidStorageAccess {
allowed: crate::StorageAccess::empty(),
seen: crate::StorageAccess::empty(),
});
}
};
if !is_storage && !var.storage_access.is_empty() {
return Err(GlobalVariableError::InvalidStorageAccess);
if !allowed_storage_access.contains(var.storage_access) {
return Err(GlobalVariableError::InvalidStorageAccess {
allowed: allowed_storage_access,
seen: var.storage_access,
});
}
Ok(())
}
fn validate_local_var(
&self,
var: &crate::LocalVariable,
_fun: &crate::Function,
_types: &Arena<crate::Type>,
) -> Result<(), LocalVariableError> {
log::debug!("var {:?}", var);
if let Some(_expr_handle) = var.init {
if false {
return Err(LocalVariableError::InitializerConst);
}
}
Ok(())
}
fn validate_function(
&mut self,
fun: &crate::Function,
@ -182,10 +252,19 @@ impl Validator {
global_vars: &module.global_variables,
local_vars: &fun.local_variables,
functions: &module.functions,
parameter_types: &fun.parameter_types,
arguments: &fun.arguments,
};
self.typifier
.resolve_all(&fun.expressions, &module.types, &resolve_ctx)?;
for (var_handle, var) in fun.local_variables.iter() {
self.validate_local_var(var, fun, &module.types)
.map_err(|error| FunctionError::LocalVariable {
handle: var_handle,
name: var.name.clone().unwrap_or_default(),
error,
})?;
}
Ok(())
}
@ -226,17 +305,12 @@ impl Validator {
match (stage, var.class) {
(crate::ShaderStage::Vertex, crate::StorageClass::Output)
| (crate::ShaderStage::Fragment, crate::StorageClass::Input) => {
match module.types[var.ty].inner {
crate::TypeInner::Scalar { kind, .. }
| crate::TypeInner::Vector { kind, .. } => {
if kind != crate::ScalarKind::Float
&& var.interpolation != Some(crate::Interpolation::Flat)
{
return Err(EntryPointError::InvalidIntegerInterpolation);
}
match module.types[var.ty].inner.scalar_kind() {
Some(crate::ScalarKind::Float) => {}
Some(_) if var.interpolation != Some(crate::Interpolation::Flat) => {
return Err(EntryPointError::InvalidIntegerInterpolation);
}
crate::TypeInner::Matrix { .. } => {}
_ => unreachable!(),
_ => {}
}
}
_ => {}
@ -291,26 +365,19 @@ impl Validator {
location_out_mask |= mask;
crate::GlobalUse::LOAD | crate::GlobalUse::STORE
}
crate::StorageClass::Constant => crate::GlobalUse::LOAD,
crate::StorageClass::Uniform | crate::StorageClass::StorageBuffer => {
//TODO: built-in checks?
let mut storage_usage = crate::GlobalUse::empty();
if var.storage_access.contains(crate::StorageAccess::LOAD) {
storage_usage |= crate::GlobalUse::LOAD;
}
if var.storage_access.contains(crate::StorageAccess::STORE) {
storage_usage |= crate::GlobalUse::STORE;
}
if storage_usage.is_empty() {
// its a uniform buffer
crate::GlobalUse::LOAD
} else {
storage_usage
}
}
crate::StorageClass::Uniform => crate::GlobalUse::LOAD,
crate::StorageClass::Storage => storage_usage(var.storage_access),
crate::StorageClass::Handle => match module.types[var.ty].inner {
crate::TypeInner::Image {
class: crate::ImageClass::Storage(_),
..
} => storage_usage(var.storage_access),
_ => crate::GlobalUse::LOAD,
},
crate::StorageClass::Private | crate::StorageClass::WorkGroup => {
crate::GlobalUse::all()
}
crate::StorageClass::PushConstant => crate::GlobalUse::LOAD,
};
if !allowed_usage.contains(usage) {
log::warn!("\tUsage error for: {:?}", var);
@ -364,10 +431,22 @@ impl Validator {
return Err(ValidationError::UnresolvedType(base));
}
}
Ti::Array { base, .. } => {
Ti::Array { base, size, .. } => {
if base >= handle {
return Err(ValidationError::UnresolvedType(base));
}
if let crate::ArraySize::Constant(const_handle) = size {
let constant = module
.constants
.try_get(const_handle)
.ok_or(ValidationError::Corrupted)?;
match constant.inner {
crate::ConstantInner::Uint(_) => {}
_ => {
return Err(ValidationError::InvalidArraySizeConstant(const_handle))
}
}
}
}
Ti::Struct { ref members } => {
//TODO: check that offsets are not intersecting?
@ -384,7 +463,11 @@ impl Validator {
for (var_handle, var) in module.global_variables.iter() {
self.validate_global_var(var, &module.types)
.map_err(|e| ValidationError::GlobalVariable(var_handle, e))?;
.map_err(|error| ValidationError::GlobalVariable {
handle: var_handle,
name: var.name.clone().unwrap_or_default(),
error,
})?;
}
for (fun_handle, fun) in module.functions.iter() {

View file

@ -1,7 +1,8 @@
(
spv_flow_dump_prefix: "",
metal_bindings: {
(group: 0, binding: 0): (buffer: Some(0), mutable: false),
(group: 0, binding: 1): (buffer: Some(1), mutable: true),
(group: 0, binding: 2): (buffer: Some(2), mutable: true),
(stage: Compute, group: 0, binding: 0): (buffer: Some(0), mutable: false),
(stage: Compute, group: 0, binding: 1): (buffer: Some(1), mutable: true),
(stage: Compute, group: 0, binding: 2): (buffer: Some(2), mutable: true),
}
)

View file

@ -61,8 +61,8 @@ type Particles = struct {
};
[[group(0), binding(0)]] var<uniform> params : SimParams;
[[group(0), binding(1)]] var<storage_buffer> particlesA : Particles;
[[group(0), binding(2)]] var<storage_buffer> particlesB : Particles;
[[group(0), binding(1)]] var<storage> particlesA : Particles;
[[group(0), binding(2)]] var<storage> particlesB : Particles;
[[builtin(global_invocation_id)]] var gl_GlobalInvocationID : vec3<u32>;

View file

@ -1,6 +1,6 @@
(
metal_bindings: {
(group: 0, binding: 0): (texture: Some(0)),
(group: 0, binding: 1): (sampler: Some(0)),
(stage: Fragment, group: 0, binding: 0): (texture: Some(0)),
(stage: Fragment, group: 0, binding: 1): (sampler: Some(0)),
}
)

View file

@ -14,8 +14,8 @@ fn main() -> void {
# fragment
[[location(0)]] var<in> v_uv : vec2<f32>;
[[group(0), binding(0)]] var<uniform> u_texture : texture_sampled_2d<f32>;
[[group(0), binding(1)]] var<uniform> u_sampler : sampler;
[[group(0), binding(0)]] var u_texture : texture_sampled_2d<f32>;
[[group(0), binding(1)]] var u_sampler : sampler;
[[location(0)]] var<out> o_color : vec4<f32>;
[[stage(fragment)]]

View file

@ -48,6 +48,7 @@
class: Input,
binding: Some(Location(0)),
ty: 1,
init: None,
interpolation: None,
storage_access: (
bits: 0,
@ -58,6 +59,7 @@
class: Output,
binding: Some(Location(0)),
ty: 2,
init: None,
interpolation: None,
storage_access: (
bits: 0,
@ -71,7 +73,7 @@
workgroup_size: (0, 0, 0),
function: (
name: Some("main"),
parameter_types: [],
arguments: [],
return_type: None,
global_usage: [
(
@ -85,7 +87,7 @@
(
name: Some("w"),
ty: 3,
init: Some(3),
init: Some(1),
),
],
expressions: [

View file

@ -14,13 +14,13 @@ fn load_wgsl(name: &str) -> naga::Module {
fn load_spv(name: &str) -> naga::Module {
let path = format!("{}/test-data/spv/{}", env!("CARGO_MANIFEST_DIR"), name);
let input = std::fs::read(path).unwrap();
naga::front::spv::parse_u8_slice(&input).unwrap()
naga::front::spv::parse_u8_slice(&input, &Default::default()).unwrap()
}
#[cfg(feature = "glsl-in")]
fn load_glsl(name: &str, entry: &str, stage: naga::ShaderStage) -> naga::Module {
let input = load_test_data(name);
naga::front::glsl::parse_str(&input, entry, stage).unwrap()
naga::front::glsl::parse_str(&input, entry, stage, Default::default()).unwrap()
}
#[cfg(feature = "wgsl-in")]
@ -34,6 +34,7 @@ fn convert_quad() {
let mut binding_map = msl::BindingMap::default();
binding_map.insert(
msl::BindSource {
stage: naga::ShaderStage::Fragment,
group: 0,
binding: 0,
},
@ -46,6 +47,7 @@ fn convert_quad() {
);
binding_map.insert(
msl::BindSource {
stage: naga::ShaderStage::Fragment,
group: 0,
binding: 1,
},
@ -57,9 +59,11 @@ fn convert_quad() {
},
);
let options = msl::Options {
binding_map: &binding_map,
lang_version: (1, 0),
spirv_cross_compatibility: false,
binding_map,
};
msl::write_string(&module, options).unwrap();
msl::write_string(&module, &options).unwrap();
}
}
@ -74,6 +78,7 @@ fn convert_boids() {
let mut binding_map = msl::BindingMap::default();
binding_map.insert(
msl::BindSource {
stage: naga::ShaderStage::Compute,
group: 0,
binding: 0,
},
@ -86,6 +91,7 @@ fn convert_boids() {
);
binding_map.insert(
msl::BindSource {
stage: naga::ShaderStage::Compute,
group: 0,
binding: 1,
},
@ -98,6 +104,7 @@ fn convert_boids() {
);
binding_map.insert(
msl::BindSource {
stage: naga::ShaderStage::Compute,
group: 0,
binding: 2,
},
@ -109,9 +116,11 @@ fn convert_boids() {
},
);
let options = msl::Options {
binding_map: &binding_map,
lang_version: (1, 0),
spirv_cross_compatibility: false,
binding_map,
};
msl::write_string(&module, options).unwrap();
msl::write_string(&module, &options).unwrap();
}
}
@ -129,6 +138,7 @@ fn convert_cube() {
let mut binding_map = msl::BindingMap::default();
binding_map.insert(
msl::BindSource {
stage: naga::ShaderStage::Vertex,
group: 0,
binding: 0,
},
@ -141,6 +151,7 @@ fn convert_cube() {
);
binding_map.insert(
msl::BindSource {
stage: naga::ShaderStage::Fragment,
group: 0,
binding: 1,
},
@ -153,6 +164,7 @@ fn convert_cube() {
);
binding_map.insert(
msl::BindSource {
stage: naga::ShaderStage::Fragment,
group: 0,
binding: 2,
},
@ -164,10 +176,12 @@ fn convert_cube() {
},
);
let options = msl::Options {
binding_map: &binding_map,
lang_version: (1, 0),
spirv_cross_compatibility: false,
binding_map,
};
msl::write_string(&vs, options).unwrap();
msl::write_string(&fs, options).unwrap();
msl::write_string(&vs, &options).unwrap();
msl::write_string(&fs, &options).unwrap();
}
}

View file

@ -40,15 +40,33 @@ fn test_rosetta(dir_name: &str) {
#[cfg(feature = "glsl-in")]
{
if let Ok(input) = fs::read_to_string(dir_path.join("x.vert")) {
let module = glsl::parse_str(&input, "main", naga::ShaderStage::Vertex).unwrap();
let module = glsl::parse_str(
&input,
"main",
naga::ShaderStage::Vertex,
Default::default(),
)
.unwrap();
check("vert", &module, &expected);
}
if let Ok(input) = fs::read_to_string(dir_path.join("x.frag")) {
let module = glsl::parse_str(&input, "main", naga::ShaderStage::Fragment).unwrap();
let module = glsl::parse_str(
&input,
"main",
naga::ShaderStage::Fragment,
Default::default(),
)
.unwrap();
check("frag", &module, &expected);
}
if let Ok(input) = fs::read_to_string(dir_path.join("x.comp")) {
let module = glsl::parse_str(&input, "main", naga::ShaderStage::Compute).unwrap();
let module = glsl::parse_str(
&input,
"main",
naga::ShaderStage::Compute,
Default::default(),
)
.unwrap();
check("comp", &module, &expected);
}
}