forked from mirrors/gecko-dev
--HG-- extra : rebase_source : 0f772aa20c592a46cfdbc83a64b218a4568ff28d extra : histedit_source : 7148969780d2ba4fb5aeecdcb413855a4babef33
894 lines
28 KiB
Python
894 lines
28 KiB
Python
"""
|
|
Type Inference
|
|
"""
|
|
from .typevar import TypeVar
|
|
from .ast import Def, Var
|
|
from copy import copy
|
|
from itertools import product
|
|
|
|
try:
|
|
from typing import Dict, TYPE_CHECKING, Union, Tuple, Optional, Set # noqa
|
|
from typing import Iterable, List, Any, TypeVar as MTypeVar # noqa
|
|
from typing import cast
|
|
from .xform import Rtl, XForm # noqa
|
|
from .ast import Expr # noqa
|
|
from .typevar import TypeSet # noqa
|
|
if TYPE_CHECKING:
|
|
T = MTypeVar('T')
|
|
TypeMap = Dict[TypeVar, TypeVar]
|
|
VarTyping = Dict[Var, TypeVar]
|
|
except ImportError:
|
|
TYPE_CHECKING = False
|
|
pass
|
|
|
|
|
|
class TypeConstraint(object):
    """
    Base class for all runtime-emittable type constraints.

    Subclasses provide _args() (the exact constructor arguments) as well as
    is_trivial() and eval(). Equality, hashing, translation, tvs() and repr
    are all derived from _args().
    """

    def __init__(self, tv, tc):
        # type: (TypeVar, Union[TypeVar, TypeSet]) -> None
        """
        Abstract "constructor" for linters
        """
        assert False, "Abstract"

    def translate(self, m):
        # type: (Union[TypeEnv, TypeMap]) -> TypeConstraint
        """
        Translate any TypeVars in the constraint according to the map or
        TypeEnv m, returning a new constraint of the same class.

        With a TypeEnv, each TypeVar is replaced by m[tv]. With a plain
        TypeMap, subst() is applied, so free TypeVars that are absent from
        the map are left unchanged. Non-TypeVar arguments (e.g. TypeSets)
        are passed through as-is.
        """
        def translate_one(a):
            # type: (Any) -> Any
            if not isinstance(a, TypeVar):
                return a
            return m[a] if isinstance(m, TypeEnv) else subst(a, m)

        return self.__class__(*[translate_one(a) for a in self._args()])

    def __eq__(self, other):
        # type: (object) -> bool
        if not isinstance(other, self.__class__):
            return False

        assert isinstance(other, TypeConstraint)  # help MyPy figure out other
        return self._args() == other._args()

    def __ne__(self, other):
        # type: (object) -> bool
        # Python 2 does not derive `!=` from __eq__. Without this method, two
        # distinct but equal constraints would compare as unequal with `!=`.
        return not self.__eq__(other)

    def is_concrete(self):
        # type: () -> bool
        """
        Return true iff all typevars in the constraint are singletons.
        """
        return all(tv.singleton_type() is not None for tv in self.tvs())

    def __hash__(self):
        # type: () -> int
        return hash(self._args())

    def _args(self):
        # type: () -> Tuple[Any,...]
        """
        Return a tuple with the exact arguments passed to __init__ to create
        this object.
        """
        assert False, "Abstract"

    def tvs(self):
        # type: () -> Iterable[TypeVar]
        """
        Return the typevars contained in this constraint.
        """
        return [a for a in self._args() if isinstance(a, TypeVar)]

    def is_trivial(self):
        # type: () -> bool
        """
        Return true if this constraint is statically decidable.
        """
        assert False, "Abstract"

    def eval(self):
        # type: () -> bool
        """
        Evaluate this constraint. Should only be called when the constraint has
        been translated to concrete types.
        """
        assert False, "Abstract"

    def __repr__(self):
        # type: () -> str
        return '{}({})'.format(self.__class__.__name__,
                               ', '.join(map(str, self._args())))
|
|
|
|
|
|
class TypesEqual(TypeConstraint):
    """
    Constraint requiring two (possibly derived) type vars to resolve to the
    same runtime type.
    """
    def __init__(self, tv1, tv2):
        # type: (TypeVar, TypeVar) -> None
        # Store the pair in a canonical order (sorted by repr) so that
        # TypesEqual(a, b) and TypesEqual(b, a) have identical _args().
        ordered = sorted((tv1, tv2), key=repr)
        self.tv1 = ordered[0]
        self.tv2 = ordered[1]

    def _args(self):
        # type: () -> Tuple[Any,...]
        """ See TypeConstraint._args() """
        return (self.tv1, self.tv2)

    def is_trivial(self):
        # type: () -> bool
        """ See TypeConstraint.is_trivial() """
        if self.tv1 == self.tv2:
            # Identical type vars: always satisfied.
            return True
        return self.is_concrete()

    def eval(self):
        # type: () -> bool
        """ See TypeConstraint.eval() """
        assert self.is_concrete()
        lhs = self.tv1.singleton_type()
        rhs = self.tv2.singleton_type()
        return lhs == rhs
|
|
|
|
|
|
class InTypeset(TypeConstraint):
    """
    Constraint requiring a free "typeof_*" type var to take its value from a
    given typeset.
    """
    def __init__(self, tv, ts):
        # type: (TypeVar, TypeSet) -> None
        assert not tv.is_derived and tv.name.startswith("typeof_")
        self.tv = tv
        self.ts = ts

    def _args(self):
        # type: () -> Tuple[Any,...]
        """ See TypeConstraint._args() """
        return (self.tv, self.ts)

    def is_trivial(self):
        # type: () -> bool
        """ See TypeConstraint.is_trivial() """
        # Work on a copy: the intersection below must not touch tv's typeset.
        current = self.tv.get_typeset().copy()

        if current.issubset(self.ts):
            # Trivially true: every type tv may take is already allowed.
            return True

        # Trivially false when tv cannot take any type from ts.
        current &= self.ts
        return current.size() == 0 or self.is_concrete()

    def eval(self):
        # type: () -> bool
        """ See TypeConstraint.eval() """
        assert self.is_concrete()
        current = self.tv.get_typeset()
        return current.issubset(self.ts)
|
|
|
|
|
|
class WiderOrEq(TypeConstraint):
    """
    Constraint specifying that a type var tv1 must be wider than or equal to
    type var tv2 at runtime. This requires that:
        1) They have the same number of lanes
        2) In a lane tv1 has at least as many bits as tv2.
    """
    def __init__(self, tv1, tv2):
        # type: (TypeVar, TypeVar) -> None
        self.tv1 = tv1
        self.tv2 = tv2

    def _args(self):
        # type: () -> Tuple[Any,...]
        """ See TypeConstraint._args() """
        return (self.tv1, self.tv2)

    def is_trivial(self):
        # type: () -> bool
        """
        See TypeConstraint.is_trivial().

        Returns True when one of the static checks below decides the
        constraint from the typesets alone, or when both tvs are concrete.
        """
        # Trivially true
        if (self.tv1 == self.tv2):
            return True

        ts1 = self.tv1.get_typeset()
        ts2 = self.tv2.get_typeset()

        def set_wider_or_equal(s1, s2):
            # type: (Set[int], Set[int]) -> bool
            # Every element of s1 is >= every element of s2.
            return len(s1) > 0 and len(s2) > 0 and min(s1) >= max(s2)

        # Trivially True
        if (set_wider_or_equal(ts1.ints, ts2.ints) and
                set_wider_or_equal(ts1.floats, ts2.floats) and
                set_wider_or_equal(ts1.bools, ts2.bools)):
            return True

        def set_narrower(s1, s2):
            # type: (Set[int], Set[int]) -> bool
            # Every element of s1 is < every element of s2. Comparing
            # min(s1) with max(s2) would only show that *some* pair is
            # narrower, which does not make the constraint trivially false.
            return len(s1) > 0 and len(s2) > 0 and max(s1) < min(s2)

        # Trivially False
        if (set_narrower(ts1.ints, ts2.ints) and
                set_narrower(ts1.floats, ts2.floats) and
                set_narrower(ts1.bools, ts2.bools)):
            return True

        # Trivially False: no lane count is shared.
        if len(ts1.lanes.intersection(ts2.lanes)) == 0:
            return True

        return self.is_concrete()

    def eval(self):
        # type: () -> bool
        """ See TypeConstraint.eval() """
        assert self.is_concrete()
        typ1 = self.tv1.singleton_type()
        typ2 = self.tv2.singleton_type()

        return typ1.wider_or_equal(typ2)
|
|
|
|
|
|
class SameWidth(TypeConstraint):
    """
    Constraint requiring two type vars to have the same total width, e.g.
    i32x2 has the same width as i64x1, i16x4, f32x2, f64, b1x64 etc.
    """
    def __init__(self, tv1, tv2):
        # type: (TypeVar, TypeVar) -> None
        self.tv1 = tv1
        self.tv2 = tv2

    def _args(self):
        # type: () -> Tuple[Any,...]
        """ See TypeConstraint._args() """
        return (self.tv1, self.tv2)

    def is_trivial(self):
        # type: () -> bool
        """ See TypeConstraint.is_trivial() """
        if self.tv1 == self.tv2:
            # Same type var: always satisfied.
            return True

        widths1 = self.tv1.get_typeset().widths()
        widths2 = self.tv2.get_typeset().widths()

        if not widths1.intersection(widths2):
            # No width in common: never satisfiable.
            return True

        return self.is_concrete()

    def eval(self):
        # type: () -> bool
        """ See TypeConstraint.eval() """
        assert self.is_concrete()
        width1 = self.tv1.singleton_type().width()
        width2 = self.tv2.singleton_type().width()
        return width1 == width2
|
|
|
|
|
|
class TypeEnv(object):
    """
    Class encapsulating the necessary bookkeeping for type inference.

    :attribute type_map: dict holding the equivalence relations between tvs.
                         __getitem__ follows chains of entries to find the
                         canonical representative of a tv.
    :attribute constraints: a list of accumulated TypeConstraint objects
                            (e.g. TypesEqual, WiderOrEq, SameWidth).
                            InTypeset constraints are not stored here; they
                            are applied directly to the typeset of their tv.
    :attribute ranks: dictionary recording the (optional) ranks for tvs.
                      'rank' is a partial ordering on TVs based on their
                      origin. See comments in rank() and register().
    :attribute vars: a set containing all known Vars
    :attribute idx: counter used to get fresh ids
    """

    RANK_SINGLETON = 5
    RANK_INPUT = 4
    RANK_INTERMEDIATE = 3
    RANK_OUTPUT = 2
    RANK_TEMP = 1
    RANK_INTERNAL = 0

    def __init__(self, arg=None):
        # type: (Optional[Tuple[TypeMap, List[TypeConstraint]]]) -> None
        """
        Create an empty environment, or one seeded from a
        (type_map, constraints) pair. The seed objects are used directly,
        not copied.
        """
        self.ranks = {}  # type: Dict[TypeVar, int]
        self.vars = set()  # type: Set[Var]

        if arg is None:
            self.type_map = {}  # type: TypeMap
            self.constraints = []  # type: List[TypeConstraint]
        else:
            self.type_map, self.constraints = arg

        self.idx = 0

    def __getitem__(self, arg):
        # type: (Union[TypeVar, Var]) -> TypeVar
        """
        Lookup the canonical representative for a Var/TypeVar.

        For a derived TV, the base is resolved recursively and a derived TV
        is rebuilt on top of the canonical base.
        """
        if isinstance(arg, Var):
            assert arg in self.vars
            tv = arg.get_typevar()
        else:
            assert isinstance(arg, TypeVar)
            tv = arg

        while tv in self.type_map:
            tv = self.type_map[tv]

        if tv.is_derived:
            tv = TypeVar.derived(self[tv.base], tv.derived_func)
        return tv

    def equivalent(self, tv1, tv2):
        # type: (TypeVar, TypeVar) -> None
        """
        Record that the free tv1 is part of the same equivalence class as
        tv2. The canonical representative of the merged class is tv2's
        canonical representative.
        """
        assert not tv1.is_derived
        assert self[tv1] == tv1

        # Make sure we don't create cycles
        if tv2.is_derived:
            assert self[tv2.base] != tv1

        self.type_map[tv1] = tv2

    def add_constraint(self, constr):
        # type: (TypeConstraint) -> None
        """
        Add a new constraint. Constraints already present are ignored.
        """
        if constr in self.constraints:
            return

        # InTypeset constraints can be expressed by constraining the typeset of
        # a variable. No need to add them to self.constraints
        if isinstance(constr, InTypeset):
            self[constr.tv].constrain_types_by_ts(constr.ts)
            return

        self.constraints.append(constr)

    def get_uid(self):
        # type: () -> str
        """
        Return a fresh id: the decimal string of an increasing counter.
        """
        r = str(self.idx)
        self.idx += 1
        return r

    def __repr__(self):
        # type: () -> str
        return self.dot()

    def rank(self, tv):
        # type: (TypeVar) -> int
        """
        Get the rank of tv in the partial order. TVs directly associated with a
        Var get their rank from the Var (see register()). Internally generated
        non-derived TVs implicitly get the lowest rank (0). Derived variables
        get their rank from their free typevar. Singletons have the highest
        rank. TVs associated with vars in a source pattern have a higher rank
        than TVs associated with temporary vars.
        """
        if tv.singleton_type() is None:
            default_rank = TypeEnv.RANK_INTERNAL
        else:
            default_rank = TypeEnv.RANK_SINGLETON

        if tv.is_derived:
            tv = tv.free_typevar()

        return self.ranks.get(tv, default_rank)

    def register(self, v):
        # type: (Var) -> None
        """
        Register a new Var v. This computes a rank for the associated TypeVar
        for v, which is used to impose a partial order on type variables.
        """
        self.vars.add(v)

        if v.is_input():
            r = TypeEnv.RANK_INPUT
        elif v.is_intermediate():
            r = TypeEnv.RANK_INTERMEDIATE
        elif v.is_output():
            r = TypeEnv.RANK_OUTPUT
        else:
            assert(v.is_temp())
            r = TypeEnv.RANK_TEMP

        self.ranks[v.get_typevar()] = r

    def free_typevars(self):
        # type: () -> List[TypeVar]
        """
        Get the free typevars in the current type env, sorted by name.
        """
        tvs = set([self[tv].free_typevar() for tv in self.type_map.keys()])
        tvs = tvs.union(set([self[v].free_typevar() for v in self.vars]))
        # Filter out None here due to singleton type vars
        return sorted(filter(lambda x: x is not None, tvs),
                      key=lambda x: x.name)

    def normalize(self):
        # type: () -> None
        """
        Normalize by:
            - collapsing any roots that don't correspond to a concrete TV AND
              have a single TV derived from them or equivalent to them

        E.g. if we have a root of the tree that looks like:

          typeof_a   typeof_b
                 \\  /
               typeof_x
                   |
                 half_width(1)
                   |
                   1

        we want to collapse the linear path between 1 and typeof_x. The
        resulting graph is:

          typeof_a   typeof_b
                 \\  /
               typeof_x
        """
        source_tvs = set([v.get_typevar() for v in self.vars])

        # children[t]: TVs derived from free TV t, plus TVs mapped to t
        # through type_map.
        children = {}  # type: Dict[TypeVar, Set[TypeVar]]
        for v in self.type_map.values():
            if not v.is_derived:
                continue
            children.setdefault(v.free_typevar(), set()).add(v)

        for (a, b) in self.type_map.items():
            children.setdefault(b, set()).add(a)

        for r in self.free_typevars():
            while (r not in source_tvs and r in children and
                   len(children[r]) == 1):
                child = list(children[r])[0]
                if child in self.type_map:
                    assert self.type_map[child] == r
                    del self.type_map[child]

                r = child

    def extract(self):
        # type: () -> TypeEnv
        """
        Extract a clean type environment from self, that only mentions
        TVs associated with real variables. Constraints are translated to
        canonical TVs; trivial and duplicate constraints are dropped.
        """
        vars_tvs = set([v.get_typevar() for v in self.vars])

        new_type_map = {}  # type: TypeMap
        for tv in vars_tvs:
            canonical = self[tv]
            if tv != canonical:
                new_type_map[tv] = canonical

        new_constraints = []  # type: List[TypeConstraint]
        for constr in self.constraints:
            constr = constr.translate(self)

            if constr.is_trivial() or constr in new_constraints:
                continue

            # Sanity: translated constraints should refer to only real vars
            for arg in constr._args():
                if not isinstance(arg, TypeVar):
                    continue

                arg_free_tv = arg.free_typevar()
                assert arg_free_tv is None or arg_free_tv in vars_tvs

            new_constraints.append(constr)

        # Sanity: translated typemap should refer to only real vars
        for (k, v) in new_type_map.items():
            assert k in vars_tvs
            assert v.free_typevar() is None or v.free_typevar() in vars_tvs

        t = TypeEnv()
        t.type_map = new_type_map
        t.constraints = new_constraints
        # ranks and vars contain only TVs associated with real vars
        t.ranks = copy(self.ranks)
        t.vars = copy(self.vars)
        return t

    def concrete_typings(self):
        # type: () -> Iterable[VarTyping]
        """
        Return an iterable over all possible concrete typings permitted by this
        TypeEnv.
        """
        free_tvs = self.free_typevars()
        free_tv_iters = [tv.get_typeset().concrete_types() for tv in free_tvs]
        for concrete_types in product(*free_tv_iters):
            # Build type substitutions for all free vars
            m = {tv: TypeVar.singleton(typ)
                 for (tv, typ) in zip(free_tvs, concrete_types)}

            concrete_var_map = {v: subst(self[v.get_typevar()], m)
                                for v in self.vars}

            # Only yield typings that satisfy every constraint.
            for constr in self.constraints:
                if not constr.translate(m).eval():
                    break
            else:
                yield concrete_var_map

    def permits(self, concrete_typing):
        # type: (VarTyping) -> bool
        """
        Return true iff this TypeEnv permits the (possibly partial) concrete
        variable type mapping concrete_typing.
        """
        # Each variable has a concrete type, that is a subset of its inferred
        # typeset.
        for (v, typ) in concrete_typing.items():
            assert typ.singleton_type() is not None
            if not typ.get_typeset().issubset(self[v].get_typeset()):
                return False

        m = {self[v]: typ for (v, typ) in concrete_typing.items()}

        # Constraints involving vars in concrete_typing are satisfied.
        # translate() with a plain map uses subst(), which leaves unmapped
        # TVs in place rather than raising. A constraint that still mentions
        # non-singleton TVs after translation therefore depends on vars
        # outside concrete_typing. It cannot be decided yet, so skip it
        # instead of calling eval() on it.
        for constr in self.constraints:
            concrete_constr = constr.translate(m)
            if not concrete_constr.is_concrete():
                continue
            if not concrete_constr.eval():
                return False

        return True

    def dot(self):
        # type: () -> str
        """
        Return a representation of self as a graph in dot format.
        Nodes correspond to TypeVariables.
        Dotted edges correspond to equivalences between TVS
        Solid edges correspond to derivation relations between TVs.
        Dashed edges correspond to equivalence constraints.
        """
        def label(s):
            # type: (TypeVar) -> str
            return "\"" + str(s) + "\""

        # Add all registered TVs (as some of them may be singleton nodes not
        # appearing in the graph)
        nodes = set()  # type: Set[TypeVar]
        edges = set()  # type: Set[Tuple[TypeVar, TypeVar, str, str, Optional[str]]] # noqa

        def add_nodes(*args):
            # type: (*TypeVar) -> None
            # Add each tv and every base on its derivation chain, with a solid
            # edge (labelled by the derived function) for each step.
            for tv in args:
                nodes.add(tv)
                while (tv.is_derived):
                    nodes.add(tv.base)
                    edges.add((tv, tv.base, "solid", "forward",
                               tv.derived_func))
                    tv = tv.base

        for v in self.vars:
            add_nodes(v.get_typevar())

        for (tv1, tv2) in self.type_map.items():
            # Add all intermediate TVs appearing in edges
            add_nodes(tv1, tv2)
            edges.add((tv1, tv2, "dotted", "forward", None))

        for constr in self.constraints:
            if isinstance(constr, TypesEqual):
                add_nodes(constr.tv1, constr.tv2)
                edges.add((constr.tv1, constr.tv2, "dashed", "none", "equal"))
            elif isinstance(constr, WiderOrEq):
                add_nodes(constr.tv1, constr.tv2)
                edges.add((constr.tv1, constr.tv2, "dashed", "forward", ">="))
            elif isinstance(constr, SameWidth):
                add_nodes(constr.tv1, constr.tv2)
                edges.add((constr.tv1, constr.tv2, "dashed", "none",
                           "same_width"))
            else:
                assert False, "Can't display constraint {}".format(constr)

        root_nodes = set([x for x in nodes
                          if x not in self.type_map and not x.is_derived])

        r = "digraph {\n"
        for n in nodes:
            r += label(n)
            if n in root_nodes:
                r += "[xlabel=\"{}\"]".format(self[n].get_typeset())
            r += ";\n"

        for (n1, n2, style, direction, elabel) in edges:
            e = label(n1) + "->" + label(n2)
            e += "[style={},dir={}".format(style, direction)

            if elabel is not None:
                e += ",label=\"{}\"".format(elabel)
            e += "];\n"

            r += e
        r += "}"

        return r
|
|
|
|
|
|
if TYPE_CHECKING:
    # Type inference reports failures as plain message strings.
    TypingError = str
    # unify() and the ti_* functions return either an updated TypeEnv or an
    # error message.
    TypingOrError = Union[TypeEnv, TypingError]
|
|
|
|
|
|
def get_error(typing_or_err):
    # type: (TypingOrError) -> Optional[TypingError]
    """
    Return the error message held in a typing result, or None if the result
    is not an error string. Exists mainly to help mypy narrow the union.
    """
    if not isinstance(typing_or_err, str):
        return None

    if not TYPE_CHECKING:
        return typing_or_err
    return cast(TypingError, typing_or_err)
|
|
|
|
|
|
def get_type_env(typing_or_err):
    # type: (TypingOrError) -> TypeEnv
    """
    Return a typing result that is expected to be a TypeEnv. With assertions
    enabled, anything else fails with the offending value in the message.
    Exists mainly to help mypy narrow the union.
    """
    assert isinstance(typing_or_err, TypeEnv), \
        "Unexpected error: {}".format(typing_or_err)

    if not TYPE_CHECKING:
        return typing_or_err
    return cast(TypeEnv, typing_or_err)
|
|
|
|
|
|
def subst(tv, tv_map):
    # type: (TypeVar, TypeMap) -> TypeVar
    """
    Apply the substitution tv_map to tv.

    A tv that is a key of tv_map is replaced by its image. A derived tv that
    is not a key is rebuilt from its substituted base. Any other tv is
    returned unchanged.
    """
    if tv not in tv_map:
        if not tv.is_derived:
            return tv
        new_base = subst(tv.base, tv_map)
        return TypeVar.derived(new_base, tv.derived_func)

    return tv_map[tv]
|
|
|
|
|
|
def normalize_tv(tv):
    # type: (TypeVar) -> TypeVar
    """
    Normalize a (potentially derived) TV using the following rules:

    - vector and width derived functions commute
        {HALF,DOUBLE}VECTOR({HALF,DOUBLE}WIDTH(base)) ->
        {HALF,DOUBLE}WIDTH({HALF,DOUBLE}VECTOR(base))

    - half/double pairs collapse
        {HALF,DOUBLE}WIDTH({DOUBLE,HALF}WIDTH(base)) -> base
        {HALF,DOUBLE}VECTOR({DOUBLE,HALF}VECTOR(base)) -> base

    A free (non-derived) TV is returned unchanged.
    """
    # NOTE(review): in the final fall-through, the base is normalized but the
    # outer derived function is not re-checked against the new base. E.g.
    # HALFWIDTH(HALFVECTOR(DOUBLEWIDTH(x))) becomes
    # HALFWIDTH(DOUBLEWIDTH(HALFVECTOR(x))), which still contains a
    # cancellable pair. Confirm whether callers rely on a complete normal
    # form.
    vector_derives = [TypeVar.HALFVECTOR, TypeVar.DOUBLEVECTOR]
    width_derives = [TypeVar.HALFWIDTH, TypeVar.DOUBLEWIDTH]

    if not tv.is_derived:
        return tv

    df = tv.derived_func

    if (tv.base.is_derived):
        base_df = tv.base.derived_func

        # Reordering: {HALFWIDTH, DOUBLEWIDTH} commute with {HALFVECTOR,
        # DOUBLEVECTOR}. Arbitrarily pick WIDTH < VECTOR
        # (the VECTOR function is moved inside, so WIDTH ends up outermost).
        if df in vector_derives and base_df in width_derives:
            return normalize_tv(
                    TypeVar.derived(
                        TypeVar.derived(tv.base.base, df), base_df))

        # Cancelling: HALFWIDTH, DOUBLEWIDTH and HALFVECTOR, DOUBLEVECTOR
        # cancel each other. Note: This doesn't hide any over/underflows,
        # since we 1) assert the safety of each TV in the chain upon its
        # creation, and 2) the base typeset is only allowed to shrink.

        if (df, base_df) in \
                [(TypeVar.HALFVECTOR, TypeVar.DOUBLEVECTOR),
                 (TypeVar.DOUBLEVECTOR, TypeVar.HALFVECTOR),
                 (TypeVar.HALFWIDTH, TypeVar.DOUBLEWIDTH),
                 (TypeVar.DOUBLEWIDTH, TypeVar.HALFWIDTH)]:
            return normalize_tv(tv.base.base)

    # No rule applies at the top level: normalize the base and re-apply df.
    return TypeVar.derived(normalize_tv(tv.base), df)
|
|
|
|
|
|
def constrain_fixpoint(tv1, tv2):
    # type: (TypeVar, TypeVar) -> None
    """
    Narrow the typesets of tv1 and tv2 (which may be derived from one
    another) until they agree. Because constraining one can shrink the
    other when they are related by derivation, alternate until tv1's
    typeset stops changing.
    """
    while True:
        before = tv1.get_typeset().copy()
        tv2.constrain_types(tv1)
        if tv1.get_typeset() == before:
            # tv1 did not move, so a fixpoint has been reached.
            return

        snapshot = tv2.get_typeset().copy()
        tv1.constrain_types(tv2)
        # Constraining tv1 by tv2 must not shrink tv2 any further.
        assert snapshot == tv2.get_typeset()
|
|
|
|
|
|
def unify(tv1, tv2, typ):
    # type: (TypeVar, TypeVar, TypeEnv) -> TypingOrError
    """
    Unify tv1 and tv2 in the type environment typ (mutated in place).

    Returns typ on success, or an error message string if unification
    produces an empty typeset.
    """
    lhs = normalize_tv(typ[tv1])
    rhs = normalize_tv(typ[tv2])

    if lhs == rhs:
        # Already unified.
        return typ

    # Arrange for lhs to be the lower-ranked side; it is the one that gets
    # mapped onto rhs below.
    if typ.rank(rhs) < typ.rank(lhs):
        return unify(rhs, lhs, typ)

    constrain_fixpoint(lhs, rhs)

    if lhs.get_typeset().size() == 0 or rhs.get_typeset().size() == 0:
        return "Error: empty type created when unifying {} and {}"\
               .format(lhs, rhs)

    if not lhs.is_derived:
        # Free -> Derived(Free)
        typ.equivalent(lhs, rhs)
        return typ

    # lhs is derived from here on. If its derived function is invertible,
    # apply the inverse to rhs and unify the bases instead.
    if TypeVar.is_bijection(lhs.derived_func):
        inverse = TypeVar.inverse_func(lhs.derived_func)
        rhs_base = normalize_tv(TypeVar.derived(rhs, inverse))
        return unify(lhs.base, rhs_base, typ)

    # Otherwise record the equality as an explicit constraint.
    typ.add_constraint(TypesEqual(lhs, rhs))
    return typ
|
|
|
|
|
|
def move_first(l, i):
    # type: (List[T], int) -> List[T]
    """
    Return a new list with l[i] moved to the front. The other elements keep
    their relative order, and l itself is not modified.
    """
    pivot = [l[i]]
    before = l[:i]
    after = l[i + 1:]
    return pivot + before + after
|
|
|
|
|
|
def ti_def(definition, typ):
    # type: (Def, TypeEnv) -> TypingOrError
    """
    Perform type inference on one Def in the current type environment typ and
    return an updated type environment or error.

    At a high level this works by creating fresh copies of each formal type var
    in the Def's instruction's signature, and unifying the formal tv with the
    corresponding actual tv.

    Note: typ is updated in place. Vars are registered and uids are consumed
    before any unification, so typ may already be modified when an error
    string is returned.
    """
    expr = definition.expr
    inst = expr.inst

    # Create a dict m mapping each free typevar in the signature of definition
    # to a fresh copy of itself.
    free_formal_tvs = inst.all_typevars()
    m = {tv: tv.get_fresh_copy(str(typ.get_uid())) for tv in free_formal_tvs}

    # Update m with any explicitly bound type vars. expr.typevars binds the
    # leading entries of free_formal_tvs positionally.
    for (idx, bound_typ) in enumerate(expr.typevars):
        m[free_formal_tvs[idx]] = TypeVar.singleton(bound_typ)

    # Get fresh copies for each typevar in the signature (both free and
    # derived). Order: value results first, then value operands.
    fresh_formal_tvs = \
        [subst(inst.outs[i].typevar, m) for i in inst.value_results] +\
        [subst(inst.ins[i].typevar, m) for i in inst.value_opnums]

    # Get the list of actual Vars, in the same order as fresh_formal_tvs so
    # the two lists are index-aligned.
    actual_vars = []  # type: List[Expr]
    actual_vars += [definition.defs[i] for i in inst.value_results]
    actual_vars += [expr.args[i] for i in inst.value_opnums]

    # Get the list of the actual TypeVars
    actual_tvs = []  # type: List[TypeVar]
    for v in actual_vars:
        assert(isinstance(v, Var))
        # Register with TypeEnv that this typevar corresponds to variable v,
        # and thus has a given rank
        typ.register(v)
        actual_tvs.append(v.get_typevar())

    # Make sure we unify the control typevar first.
    if inst.is_polymorphic:
        idx = fresh_formal_tvs.index(m[inst.ctrl_typevar])
        fresh_formal_tvs = move_first(fresh_formal_tvs, idx)
        actual_tvs = move_first(actual_tvs, idx)

    # Unify each actual typevar with the corresponding fresh formal tv
    for (actual_tv, formal_tv) in zip(actual_tvs, fresh_formal_tvs):
        typ_or_err = unify(actual_tv, formal_tv, typ)
        err = get_error(typ_or_err)
        if (err):
            return "fail ti on {} <: {}: ".format(actual_tv, formal_tv) + err

        typ = get_type_env(typ_or_err)

    # Add any instruction specific constraints, translated into the fresh
    # type vars of this Def.
    for constr in inst.constraints:
        typ.add_constraint(constr.translate(m))

    return typ
|
|
|
|
|
|
def ti_rtl(rtl, typ):
    # type: (Rtl, TypeEnv) -> TypingOrError
    """
    Run type inference over each Def of rtl in order, threading the type
    environment through. Returns the final environment, or the first error
    prefixed with the (0-based) index of the failing Def.
    """
    for (line_no, definition) in enumerate(rtl.rtl):
        assert isinstance(definition, Def)
        result = ti_def(definition, typ)
        err = get_error(result)  # type: Optional[TypingError]
        if err:
            return "On line {}: ".format(line_no) + err
        typ = get_type_env(result)

    return typ
|
|
|
|
|
|
def ti_xform(xform, typ):
    # type: (XForm, TypeEnv) -> TypingOrError
    """
    Perform type inference on an XForm in a starting type env typ.

    The source pattern (xform.src) is typed first. The destination pattern
    (xform.dst) is then typed in the resulting environment. Return the
    updated type environment, or an error message prefixed with the pattern
    that failed.
    """
    for (prefix, rtl) in (("In src pattern: ", xform.src),
                          ("In dst pattern: ", xform.dst)):
        typ_or_err = ti_rtl(rtl, typ)
        err = get_error(typ_or_err)  # type: Optional[TypingError]
        if err:
            return prefix + err

        typ = get_type_env(typ_or_err)

    return typ
|