Bug 1851568 - Improve validation of tail call result types. r=bvisness

Differential Revision: https://phabricator.services.mozilla.com/D187501
This commit is contained in:
Yury Delendik 2023-09-07 12:29:34 +00:00
parent e3cea5c34e
commit fd29b97c09
9 changed files with 99 additions and 57 deletions

View file

@ -0,0 +1,20 @@
wasmFailValidateText(`(module
(func (result i32 f64)
i32.const 1
f64.const 2.0
)
(func (export "f") (result f64)
return_call 0
)
)`, /type mismatch/);
wasmFailValidateText(`(module
(func (result i32 f64)
i32.const 1
f64.const 2.0
)
(func (export "f") (result f32 i32 f64)
f32.const 3.14
return_call 0
)
)`, /type mismatch/);

View file

@ -168,3 +168,53 @@ let fns = i.exports;
assertEq(fns.churn(800), -575895114); assertEq(fns.churn(800), -575895114);
assertEq(fns.churn(1200), -1164697516); assertEq(fns.churn(1200), -1164697516);
wasmValidateText(`(module
(rec
(type $s1 (sub (struct i32)))
(type $s2 (sub $s1 (struct i32 f32)))
)
(func (result (ref $s2))
struct.new_default $s2
)
(func (export "f") (result (ref $s1))
return_call 0
)
)`);
wasmFailValidateText(`(module
(rec
(type $s1 (sub (struct i32)))
(type $s2 (sub $s1 (struct i32 f32)))
)
(func (result (ref $s1))
struct.new_default $s1
)
(func (export "f") (result (ref $s2))
return_call 0
)
)`, /type mismatch/);
wasmValidateText(`(module
(rec
(type $s1 (sub (struct i32)))
(type $s2 (sub $s1 (struct i32 f32)))
)
(type $t (func (result (ref $s2))))
(func (export "f") (param (ref $t)) (result (ref $s1))
local.get 0
return_call_ref $t
)
)`);
wasmFailValidateText(`(module
(rec
(type $s1 (sub (struct i32)))
(type $s2 (sub $s1 (struct i32 f32)))
)
(type $t (func (result (ref $s1))))
(func (export "f") (param (ref $t)) (result (ref $s2))
local.get 0
return_call_ref $t
)
)`, /type mismatch/);

View file

@ -143,7 +143,7 @@ wasmFailValidateText(
(table 0 anyfunc) (table 0 anyfunc)
(func $type-void-vs-num (result i32) (func $type-void-vs-num (result i32)
(i32.eqz (return_call_indirect (type 0) (i32.const 0)))))`, (i32.eqz (return_call_indirect (type 0) (i32.const 0)))))`,
/popping value from empty stack/); /type mismatch: expected 1 values, got 0 values/);
wasmFailValidateText( wasmFailValidateText(
`(module `(module
@ -151,7 +151,7 @@ wasmFailValidateText(
(table 0 anyfunc) (table 0 anyfunc)
(func $type-num-vs-num (func $type-num-vs-num
(i32.eqz (return_call_indirect (type 0) (i32.const 0)))))`, (i32.eqz (return_call_indirect (type 0) (i32.const 0)))))`,
/unused values not explicitly dropped/); /type mismatch: expected 0 values, got 1 values/);
wasmFailValidateText( wasmFailValidateText(
`(module `(module

View file

@ -97,7 +97,7 @@ wasmFailValidateText(
(func $type-void-vs-num (result i32) (func $type-void-vs-num (result i32)
(return_call 1) (i32.const 0)) (return_call 1) (i32.const 0))
(func))`, (func))`,
/popping value from empty stack/); /type mismatch: expected 1 values, got 0 values/);
wasmFailValidateText( wasmFailValidateText(
`(module `(module

View file

@ -52,12 +52,13 @@ check_stub1: {
var ins = wasmEvalText(`(module var ins = wasmEvalText(`(module
(import "" "fac-acc" (func $fac-acc (param i64 i64) (result i64))) (import "" "fac-acc" (func $fac-acc (param i64 i64) (result i64)))
(type $ty (func (param i64 i64) (result i64))) (type $ty (func (param i64 i64) (result i64)))
(type $tz (func (param i64) (result i64)))
(table $t 1 1 funcref) (table $t 1 1 funcref)
(func $f (export "fac") (param i64) (result i64) (func $f (export "fac") (param i64) (result i64)
local.get 0 local.get 0
i64.const 1 i64.const 1
i32.const 0 i32.const 0
return_call_indirect $t return_call_indirect $t (type $tz)
) )
(elem $t (i32.const 0) $fac-acc) (elem $t (i32.const 0) $fac-acc)

View file

@ -4817,8 +4817,7 @@ bool BaseCompiler::emitCall() {
bool BaseCompiler::emitReturnCall() { bool BaseCompiler::emitReturnCall() {
uint32_t funcIndex; uint32_t funcIndex;
BaseNothingVector args_{}; BaseNothingVector args_{};
BaseNothingVector unused_values{}; if (!iter_.readReturnCall(&funcIndex, &args_)) {
if (!iter_.readReturnCall(&funcIndex, &args_, &unused_values)) {
return false; return false;
} }
@ -4934,9 +4933,8 @@ bool BaseCompiler::emitReturnCallIndirect() {
uint32_t tableIndex; uint32_t tableIndex;
Nothing callee_; Nothing callee_;
BaseNothingVector args_{}; BaseNothingVector args_{};
BaseNothingVector unused_values{};
if (!iter_.readReturnCallIndirect(&funcTypeIndex, &tableIndex, &callee_, if (!iter_.readReturnCallIndirect(&funcTypeIndex, &tableIndex, &callee_,
&args_, &unused_values)) { &args_)) {
return false; return false;
} }
@ -5042,9 +5040,7 @@ bool BaseCompiler::emitReturnCallRef() {
const FuncType* funcType; const FuncType* funcType;
Nothing unused_callee; Nothing unused_callee;
BaseNothingVector unused_args{}; BaseNothingVector unused_args{};
BaseNothingVector unused_values{}; if (!iter_.readReturnCallRef(&funcType, &unused_callee, &unused_args)) {
if (!iter_.readReturnCallRef(&funcType, &unused_callee, &unused_args,
&unused_values)) {
return false; return false;
} }

View file

@ -5101,8 +5101,7 @@ static bool EmitReturnCall(FunctionCompiler& f) {
uint32_t funcIndex; uint32_t funcIndex;
DefVector args; DefVector args;
DefVector unused_values; if (!f.iter().readReturnCall(&funcIndex, &args)) {
if (!f.iter().readReturnCall(&funcIndex, &args, &unused_values)) {
return false; return false;
} }
@ -5142,9 +5141,8 @@ static bool EmitReturnCallIndirect(FunctionCompiler& f) {
uint32_t tableIndex; uint32_t tableIndex;
MDefinition* callee; MDefinition* callee;
DefVector args; DefVector args;
DefVector unused_values;
if (!f.iter().readReturnCallIndirect(&funcTypeIndex, &tableIndex, &callee, if (!f.iter().readReturnCallIndirect(&funcTypeIndex, &tableIndex, &callee,
&args, &unused_values)) { &args)) {
return false; return false;
} }
@ -5176,9 +5174,8 @@ static bool EmitReturnCallRef(FunctionCompiler& f) {
const FuncType* funcType; const FuncType* funcType;
MDefinition* callee; MDefinition* callee;
DefVector args; DefVector args;
DefVector unused_values;
if (!f.iter().readReturnCallRef(&funcType, &callee, &args, &unused_values)) { if (!f.iter().readReturnCallRef(&funcType, &callee, &args)) {
return false; return false;
} }

View file

@ -690,12 +690,10 @@ class MOZ_STACK_CLASS OpIter : private Policy {
ValueVector* argValues); ValueVector* argValues);
#ifdef ENABLE_WASM_TAIL_CALLS #ifdef ENABLE_WASM_TAIL_CALLS
[[nodiscard]] bool readReturnCall(uint32_t* funcTypeIndex, [[nodiscard]] bool readReturnCall(uint32_t* funcTypeIndex,
ValueVector* argValues, ValueVector* argValues);
ValueVector* values);
[[nodiscard]] bool readReturnCallIndirect(uint32_t* funcTypeIndex, [[nodiscard]] bool readReturnCallIndirect(uint32_t* funcTypeIndex,
uint32_t* tableIndex, Value* callee, uint32_t* tableIndex, Value* callee,
ValueVector* argValues, ValueVector* argValues);
ValueVector* values);
#endif #endif
#ifdef ENABLE_WASM_FUNCTION_REFERENCES #ifdef ENABLE_WASM_FUNCTION_REFERENCES
[[nodiscard]] bool readCallRef(const FuncType** funcType, Value* callee, [[nodiscard]] bool readCallRef(const FuncType** funcType, Value* callee,
@ -703,8 +701,7 @@ class MOZ_STACK_CLASS OpIter : private Policy {
# ifdef ENABLE_WASM_TAIL_CALLS # ifdef ENABLE_WASM_TAIL_CALLS
[[nodiscard]] bool readReturnCallRef(const FuncType** funcType, Value* callee, [[nodiscard]] bool readReturnCallRef(const FuncType** funcType, Value* callee,
ValueVector* argValues, ValueVector* argValues);
ValueVector* values);
# endif # endif
#endif #endif
[[nodiscard]] bool readOldCallDirect(uint32_t numFuncImports, [[nodiscard]] bool readOldCallDirect(uint32_t numFuncImports,
@ -2462,8 +2459,7 @@ inline bool OpIter<Policy>::readCall(uint32_t* funcTypeIndex,
#ifdef ENABLE_WASM_TAIL_CALLS #ifdef ENABLE_WASM_TAIL_CALLS
template <typename Policy> template <typename Policy>
inline bool OpIter<Policy>::readReturnCall(uint32_t* funcTypeIndex, inline bool OpIter<Policy>::readReturnCall(uint32_t* funcTypeIndex,
ValueVector* argValues, ValueVector* argValues) {
ValueVector* values) {
MOZ_ASSERT(Classify(op_) == OpKind::ReturnCall); MOZ_ASSERT(Classify(op_) == OpKind::ReturnCall);
if (!readVarU32(funcTypeIndex)) { if (!readVarU32(funcTypeIndex)) {
@ -2480,15 +2476,11 @@ inline bool OpIter<Policy>::readReturnCall(uint32_t* funcTypeIndex,
return false; return false;
} }
if (!push(ResultType::Vector(funcType.results()))) { // Check if callee results are subtypes of caller's.
return false;
}
Control& body = controlStack_[0]; Control& body = controlStack_[0];
MOZ_ASSERT(body.kind() == LabelKind::Body); MOZ_ASSERT(body.kind() == LabelKind::Body);
if (!checkIsSubtypeOf(ResultType::Vector(funcType.results()),
// Pop function results as the instruction will cause a return. body.resultType())) {
if (!popWithType(body.resultType(), values)) {
return false; return false;
} }
@ -2549,8 +2541,7 @@ template <typename Policy>
inline bool OpIter<Policy>::readReturnCallIndirect(uint32_t* funcTypeIndex, inline bool OpIter<Policy>::readReturnCallIndirect(uint32_t* funcTypeIndex,
uint32_t* tableIndex, uint32_t* tableIndex,
Value* callee, Value* callee,
ValueVector* argValues, ValueVector* argValues) {
ValueVector* values) {
MOZ_ASSERT(Classify(op_) == OpKind::ReturnCallIndirect); MOZ_ASSERT(Classify(op_) == OpKind::ReturnCallIndirect);
MOZ_ASSERT(funcTypeIndex != tableIndex); MOZ_ASSERT(funcTypeIndex != tableIndex);
@ -2589,15 +2580,11 @@ inline bool OpIter<Policy>::readReturnCallIndirect(uint32_t* funcTypeIndex,
return false; return false;
} }
if (!push(ResultType::Vector(funcType.results()))) { // Check if callee results are subtypes of caller's.
return false;
}
Control& body = controlStack_[0]; Control& body = controlStack_[0];
MOZ_ASSERT(body.kind() == LabelKind::Body); MOZ_ASSERT(body.kind() == LabelKind::Body);
if (!checkIsSubtypeOf(ResultType::Vector(funcType.results()),
// Pop function results as the instruction will cause a return. body.resultType())) {
if (!popWithType(body.resultType(), values)) {
return false; return false;
} }
@ -2636,8 +2623,7 @@ inline bool OpIter<Policy>::readCallRef(const FuncType** funcType,
template <typename Policy> template <typename Policy>
inline bool OpIter<Policy>::readReturnCallRef(const FuncType** funcType, inline bool OpIter<Policy>::readReturnCallRef(const FuncType** funcType,
Value* callee, Value* callee,
ValueVector* argValues, ValueVector* argValues) {
ValueVector* values) {
MOZ_ASSERT(Classify(op_) == OpKind::ReturnCallRef); MOZ_ASSERT(Classify(op_) == OpKind::ReturnCallRef);
uint32_t funcTypeIndex; uint32_t funcTypeIndex;
@ -2656,15 +2642,11 @@ inline bool OpIter<Policy>::readReturnCallRef(const FuncType** funcType,
return false; return false;
} }
if (!push(ResultType::Vector((*funcType)->results()))) { // Check if callee results are subtypes of caller's.
return false;
}
Control& body = controlStack_[0]; Control& body = controlStack_[0];
MOZ_ASSERT(body.kind() == LabelKind::Body); MOZ_ASSERT(body.kind() == LabelKind::Body);
if (!checkIsSubtypeOf(ResultType::Vector((*funcType)->results()),
// Pop function results as the instruction will cause a return. body.resultType())) {
if (!popWithType(body.resultType(), values)) {
return false; return false;
} }

View file

@ -223,8 +223,7 @@ static bool DecodeFunctionBodyExprs(const ModuleEnvironment& env,
} }
uint32_t unusedIndex; uint32_t unusedIndex;
NothingVector unusedArgs{}; NothingVector unusedArgs{};
NothingVector unusedValues{}; CHECK(iter.readReturnCall(&unusedIndex, &unusedArgs));
CHECK(iter.readReturnCall(&unusedIndex, &unusedArgs, &unusedValues));
} }
case uint16_t(Op::ReturnCallIndirect): { case uint16_t(Op::ReturnCallIndirect): {
if (!env.tailCallsEnabled()) { if (!env.tailCallsEnabled()) {
@ -232,9 +231,8 @@ static bool DecodeFunctionBodyExprs(const ModuleEnvironment& env,
} }
uint32_t unusedIndex, unusedIndex2; uint32_t unusedIndex, unusedIndex2;
NothingVector unusedArgs{}; NothingVector unusedArgs{};
NothingVector unusedValues{};
CHECK(iter.readReturnCallIndirect(&unusedIndex, &unusedIndex2, &nothing, CHECK(iter.readReturnCallIndirect(&unusedIndex, &unusedIndex2, &nothing,
&unusedArgs, &unusedValues)); &unusedArgs));
} }
#endif #endif
#ifdef ENABLE_WASM_FUNCTION_REFERENCES #ifdef ENABLE_WASM_FUNCTION_REFERENCES
@ -253,9 +251,7 @@ static bool DecodeFunctionBodyExprs(const ModuleEnvironment& env,
} }
const FuncType* unusedType; const FuncType* unusedType;
NothingVector unusedArgs{}; NothingVector unusedArgs{};
NothingVector unusedValues{}; CHECK(iter.readReturnCallRef(&unusedType, &nothing, &unusedArgs));
CHECK(iter.readReturnCallRef(&unusedType, &nothing, &unusedArgs,
&unusedValues));
} }
# endif # endif
#endif #endif