Bug 1806899 - Detect link escapes in safe_extract (all m-c) r=jcristau

Apply the link escape check from tooltool to all the m-c tarfile extractions previously updated.

Differential Revision: https://phabricator.services.mozilla.com/D170660
This commit is contained in:
Geoff Brown 2023-02-24 16:26:27 +00:00
parent 7f7ab564b8
commit ab032de1aa
9 changed files with 103 additions and 52 deletions

View file

@ -139,19 +139,23 @@ def fetch_local(target, path, commit):
def validate_tar_member(member, path):
def _is_within_directory(directory, target):
abs_directory = os.path.abspath(directory)
abs_target = os.path.abspath(target)
prefix = os.path.commonprefix([abs_directory, abs_target])
return prefix == abs_directory
real_directory = os.path.realpath(directory)
real_target = os.path.realpath(target)
prefix = os.path.commonprefix([real_directory, real_target])
return prefix == real_directory
member_path = os.path.join(path, member.name)
if not _is_within_directory(path, member_path):
raise Exception("Attempted path traversal in tar file: " + member.name)
if member.issym():
link_path = os.path.join(os.path.dirname(member_path), member.linkname)
if not _is_within_directory(path, link_path):
raise Exception("Attempted link path traversal in tar file: " + member.name)
if member.mode & (stat.S_ISUID | stat.S_ISGID):
raise Exception("Attempted setuid or setgid in tar file: " + member.name)
def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
def safe_extract(tar, path=".", *, numeric_owner=False):
def _files(tar, path):
for member in tar:
validate_tar_member(member, path)

View file

@ -958,19 +958,23 @@ CHECKSUM_SUFFIX = ".checksum"
def validate_tar_member(member, path):
def _is_within_directory(directory, target):
abs_directory = os.path.abspath(directory)
abs_target = os.path.abspath(target)
prefix = os.path.commonprefix([abs_directory, abs_target])
return prefix == abs_directory
real_directory = os.path.realpath(directory)
real_target = os.path.realpath(target)
prefix = os.path.commonprefix([real_directory, real_target])
return prefix == real_directory
member_path = os.path.join(path, member.name)
if not _is_within_directory(path, member_path):
raise Exception("Attempted path traversal in tar file: " + member.name)
if member.issym():
link_path = os.path.join(os.path.dirname(member_path), member.linkname)
if not _is_within_directory(path, link_path):
raise Exception("Attempted link path traversal in tar file: " + member.name)
if member.mode & (stat.S_ISUID | stat.S_ISGID):
raise Exception("Attempted setuid or setgid in tar file: " + member.name)
def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
def safe_extract(tar, path=".", *, numeric_owner=False):
def _files(tar, path):
for member in tar:
validate_tar_member(member, path)

View file

@ -366,11 +366,34 @@ class VendorManifest(MozbuildObject):
def fetch_and_unpack(self, revision):
"""Fetch and unpack upstream source"""
def _is_within_directory(directory, target):
abs_directory = os.path.abspath(directory)
abs_target = os.path.abspath(target)
prefix = os.path.commonprefix([abs_directory, abs_target])
return prefix == abs_directory
def validate_tar_member(member, path):
def is_within_directory(directory, target):
real_directory = os.path.realpath(directory)
real_target = os.path.realpath(target)
prefix = os.path.commonprefix([real_directory, real_target])
return prefix == real_directory
member_path = os.path.join(path, member.name)
if not is_within_directory(path, member_path):
raise Exception("Attempted path traversal in tar file: " + member.name)
if member.issym():
link_path = os.path.join(os.path.dirname(member_path), member.linkname)
if not is_within_directory(path, link_path):
raise Exception(
"Attempted link path traversal in tar file: " + member.name
)
if member.mode & (stat.S_ISUID | stat.S_ISGID):
raise Exception(
"Attempted setuid or setgid in tar file: " + member.name
)
def safe_extract(tar, path=".", *, numeric_owner=False):
def _files(tar, path):
for member in tar:
validate_tar_member(member, path)
yield member
tar.extractall(path, members=_files(tar, path), numeric_owner=numeric_owner)
url = self.source_host.upstream_snapshot(revision)
self.logInfo({"url": url}, "Fetching code archive from {url}")
@ -383,20 +406,6 @@ class VendorManifest(MozbuildObject):
tmptarfile.write(data)
tmptarfile.seek(0)
tar = tarfile.open(tmptarfile.name)
for member in tar:
member_path = os.path.join(tmpextractdir.name, member.name)
if not _is_within_directory(tmpextractdir.name, member_path):
raise Exception(
"Tar archive contains non-local paths, e.g. '%s'"
% member.name
)
if member.mode & (stat.S_ISUID | stat.S_ISGID):
raise Exception(
"Tar archive has setuid or setgid member '%s'" % member.name
)
vendor_dir = mozpath.normsep(
self.manifest["vendoring"]["vendor-directory"]
)
@ -420,17 +429,18 @@ class VendorManifest(MozbuildObject):
mozfile.remove(file)
self.logInfo({"vd": vendor_dir}, "Unpacking upstream files for {vd}.")
tar.extractall(tmpextractdir.name)
with tarfile.open(tmptarfile.name) as tar:
def get_first_dir(p):
halves = os.path.split(p)
return get_first_dir(halves[0]) if halves[0] else halves[1]
safe_extract(tar, tmpextractdir.name)
one_prefix = get_first_dir(tar.getnames()[0])
has_prefix = all(
map(lambda name: name.startswith(one_prefix), tar.getnames())
)
tar.close()
def get_first_dir(p):
halves = os.path.split(p)
return get_first_dir(halves[0]) if halves[0] else halves[1]
one_prefix = get_first_dir(tar.getnames()[0])
has_prefix = all(
map(lambda name: name.startswith(one_prefix), tar.getnames())
)
# GitLab puts everything down a directory; move it up.
if has_prefix:

View file

@ -42,10 +42,10 @@ def extract_tarball(src, dest, ignore=None):
import tarfile
def _is_within_directory(directory, target):
abs_directory = os.path.abspath(directory)
abs_target = os.path.abspath(target)
prefix = os.path.commonprefix([abs_directory, abs_target])
return prefix == abs_directory
real_directory = os.path.realpath(directory)
real_target = os.path.realpath(target)
prefix = os.path.commonprefix([real_directory, real_target])
return prefix == real_directory
with tarfile.open(src) as bundle:
namelist = []
@ -65,6 +65,20 @@ def extract_tarball(src, dest, ignore=None):
"""
)
)
if m.issym():
link_path = os.path.join(os.path.dirname(member_path), m.linkname)
if not _is_within_directory(dest, link_path):
raise RuntimeError(
dedent(
f"""
Tar bundle '{src}' may be maliciously crafted to escape the destination!
The following path was detected:
{m.name}
"""
)
)
if m.mode & (stat.S_ISUID | stat.S_ISGID):
raise RuntimeError(
dedent(

View file

@ -93,22 +93,29 @@ class ContentLengthMismatch(Exception):
def _validate_tar_member(member, path):
def _is_within_directory(directory, target):
abs_directory = os.path.abspath(directory)
abs_target = os.path.abspath(target)
prefix = os.path.commonprefix([abs_directory, abs_target])
return prefix == abs_directory
real_directory = os.path.realpath(directory)
real_target = os.path.realpath(target)
prefix = os.path.commonprefix([real_directory, real_target])
return prefix == real_directory
member_path = os.path.join(path, member.name)
if not _is_within_directory(path, member_path):
raise Exception("Attempted path traversal in tar file: " + member.name)
if member.issym():
link_path = os.path.join(os.path.dirname(member_path), member.linkname)
if not _is_within_directory(path, link_path):
raise Exception("Attempted link path traversal in tar file: " + member.name)
if member.mode & (stat.S_ISUID | stat.S_ISGID):
raise Exception("Attempted setuid or setgid in tar file: " + member.name)
def _safe_extract(tar, path=".", members=None, *, numeric_owner=False):
for member in tar.getmembers():
_validate_tar_member(member, path)
tar.extractall(path, members, numeric_owner=numeric_owner)
def _safe_extract(tar, path=".", *, numeric_owner=False):
def _files(tar, path):
for member in tar:
_validate_tar_member(member, path)
yield member
tar.extractall(path, members=_files(tar, path), numeric_owner=numeric_owner)
def platform_name():

View file

@ -320,7 +320,13 @@ class TestScript(unittest.TestCase):
extract_to=self.tmpdir,
)
for archive in ("archive-setuid.tar", "archive-escape.tar"):
for archive in (
"archive-setuid.tar",
"archive-escape.tar",
"archive-link.tar",
"archive-link-abs.tar",
"archive-double-link.tar",
):
with self.assertRaises(Exception):
self.s.download_unpack(
url=os.path.join(archives_path, archive),
@ -371,7 +377,13 @@ class TestScript(unittest.TestCase):
self.tmpdir,
)
for archive in ("archive-setuid.tar", "archive-escape.tar"):
for archive in (
"archive-setuid.tar",
"archive-escape.tar",
"archive-link.tar",
"archive-link-abs.tar",
"archive-double-link.tar",
):
with self.assertRaises(Exception):
self.s.unpack(os.path.join(archives_path, archive), self.tmpdir)