Skip to content

Commit

Permalink
fix Dist class handling of .conda files
Browse files Browse the repository at this point in the history
  • Loading branch information
msarahan committed Jun 24, 2019
1 parent 4b2683e commit 60260e7
Show file tree
Hide file tree
Showing 3 changed files with 186 additions and 162 deletions.
3 changes: 3 additions & 0 deletions conda/exports.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@
from .models.version import VersionOrder, normalized_version # NOQA
VersionOrder, normalized_version = VersionOrder, normalized_version # NOQA

from .models.channel import Channel # NOQA
Channel = Channel # NOQA

import conda.base.context # NOQA
from .base.context import get_prefix, non_x86_linux_machines, reset_context, sys_rc_path # NOQA
non_x86_linux_machines, sys_rc_path = non_x86_linux_machines, sys_rc_path
Expand Down
65 changes: 37 additions & 28 deletions conda/models/dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,15 @@
from .records import PackageRecord
from .. import CondaError
from .._vendor.auxlib.entity import Entity, EntityType, IntegerField, StringField
from ..base.constants import (CONDA_PACKAGE_EXTENSION_V1, CONDA_PACKAGE_EXTENSIONS,
DEFAULTS_CHANNEL_NAME, UNKNOWN_CHANNEL)
from ..base.constants import CONDA_PACKAGE_EXTENSIONS, DEFAULTS_CHANNEL_NAME, UNKNOWN_CHANNEL
from ..base.context import context
from ..common.compat import ensure_text_type, text_type, with_metaclass
from ..common.constants import NULL
from ..common.url import has_platform, is_url, join_url

log = getLogger(__name__)
DistDetails = namedtuple('DistDetails', ('name', 'version', 'build_string', 'build_number',
'dist_name'))
'dist_name', 'fmt'))


IndexRecord = PackageRecord # for conda-build backward compat
Expand Down Expand Up @@ -53,6 +52,18 @@ def __call__(cls, *args, **kwargs):
return super(DistType, cls).__call__(*args, **kwargs)


def strip_extension(original_dist):
for ext in CONDA_PACKAGE_EXTENSIONS:
if original_dist.endswith(ext):
original_dist = original_dist[:-len(ext)]
return original_dist


def split_extension(original_dist):
stripped = strip_extension(original_dist)
return stripped, original_dist[len(stripped):]


@with_metaclass(DistType)
class Dist(Entity):
_cache_ = {}
Expand All @@ -62,6 +73,7 @@ class Dist(Entity):

dist_name = StringField(immutable=True)
name = StringField(immutable=True)
fmt = StringField(immutable=True)
version = StringField(immutable=True)
build_string = StringField(immutable=True)
build_number = IntegerField(immutable=True)
Expand All @@ -70,15 +82,16 @@ class Dist(Entity):
platform = StringField(required=False, nullable=True, immutable=True)

def __init__(self, channel, dist_name=None, name=None, version=None, build_string=None,
build_number=None, base_url=None, platform=None):
build_number=None, base_url=None, platform=None, fmt='.tar.bz2'):
super(Dist, self).__init__(channel=channel,
dist_name=dist_name,
name=name,
version=version,
build_string=build_string,
build_number=build_number,
base_url=base_url,
platform=platform)
platform=platform,
fmt=fmt)

def to_package_ref(self):
return PackageRecord(
Expand Down Expand Up @@ -123,11 +136,11 @@ def is_feature_package(self):
def is_channel(self):
return bool(self.base_url and self.platform)

def to_filename(self, extension=CONDA_PACKAGE_EXTENSION_V1):
def to_filename(self, extension=None):
if self.is_feature_package:
return self.dist_name
else:
return self.dist_name + extension
return self.dist_name + self.fmt

def to_matchspec(self):
return ' '.join(self.quad[:3])
Expand Down Expand Up @@ -158,9 +171,7 @@ def from_string(cls, string, channel_override=NULL):
)
channel, original_dist, w_f_d = re.search(REGEX_STR, string).groups()

for ext in CONDA_PACKAGE_EXTENSIONS:
if original_dist.endswith(ext):
original_dist = original_dist[:-len(ext)]
original_dist, fmt = split_extension(original_dist)

if channel_override != NULL:
channel = channel_override
Expand All @@ -174,23 +185,21 @@ def from_string(cls, string, channel_override=NULL):
version=dist_details.version,
build_string=dist_details.build_string,
build_number=dist_details.build_number,
dist_name=original_dist)
dist_name=original_dist,
fmt=fmt)

@staticmethod
def parse_dist_name(string):
original_string = string
try:
string = ensure_text_type(string)

no_tar_bz2_string = (string[:-len(CONDA_PACKAGE_EXTENSION_V1)]
if string.endswith(CONDA_PACKAGE_EXTENSION_V1)
else string)
no_fmt_string, fmt = split_extension(string)

# remove any directory or channel information
if '::' in no_tar_bz2_string:
dist_name = no_tar_bz2_string.rsplit('::', 1)[-1]
if '::' in no_fmt_string:
dist_name = no_fmt_string.rsplit('::', 1)[-1]
else:
dist_name = no_tar_bz2_string.rsplit('/', 1)[-1]
dist_name = no_fmt_string.rsplit('/', 1)[-1]

parts = dist_name.rsplit('-', 2)

Expand All @@ -202,15 +211,15 @@ def parse_dist_name(string):
if build_string else '0')))
build_number = int(build_number_as_string) if build_number_as_string else 0

return DistDetails(name, version, build_string, build_number, dist_name)
return DistDetails(name, version, build_string, build_number, dist_name, fmt)

except:
raise CondaError("dist_name is not a valid conda package: %s" % original_string)

@classmethod
def from_url(cls, url):
assert is_url(url), url
if not url.endswith(CONDA_PACKAGE_EXTENSION_V1) and '::' not in url:
if not any(url.endswith(ext) for ext in CONDA_PACKAGE_EXTENSIONS) and '::' not in url:
raise CondaError("url '%s' is not a conda package" % url)

dist_details = cls.parse_dist_name(url)
Expand All @@ -232,12 +241,13 @@ def from_url(cls, url):
build_number=dist_details.build_number,
dist_name=dist_details.dist_name,
base_url=base_url,
platform=platform)
platform=platform,
fmt=dist_details.fmt)

def to_url(self):
if not self.base_url:
return None
filename = self.dist_name + CONDA_PACKAGE_EXTENSION_V1
filename = self.dist_name + self.fmt
return (join_url(self.base_url, self.platform, filename)
if self.platform
else join_url(self.base_url, filename))
Expand All @@ -262,7 +272,9 @@ def __ge__(self, other):
return self.__key__() >= other.__key__()

def __hash__(self):
return hash(self.__key__())
# dists compare equal regardless of fmt, but fmt is taken into account for
# object identity
return hash((self.__key__(), self.fmt))

def __eq__(self, other):
return isinstance(other, self.__class__) and self.__key__() == other.__key__()
Expand All @@ -286,9 +298,7 @@ def startswith(self, match):
return self.dist_name.startswith(match)

def __contains__(self, item):
item = ensure_text_type(item)
if item.endswith(CONDA_PACKAGE_EXTENSION_V1):
item = item[:-len(CONDA_PACKAGE_EXTENSION_V1)]
item = strip_extension(ensure_text_type(item))
return item in self.__str__()

@property
Expand All @@ -297,8 +307,7 @@ def fn(self):


def dist_str_to_quad(dist_str):
if dist_str.endswith(CONDA_PACKAGE_EXTENSION_V1):
dist_str = dist_str[:-len(CONDA_PACKAGE_EXTENSION_V1)]
dist_str = strip_extension(dist_str)
if '::' in dist_str:
channel_str, dist_str = dist_str.split("::", 1)
else:
Expand Down
Loading

0 comments on commit 60260e7

Please sign in to comment.