Skip to content

Commit

Permalink
Allow the upgrade_package flag to accept a version number along with …
Browse files Browse the repository at this point in the history
…a package name
  • Loading branch information
nnja committed Nov 15, 2016
1 parent 89ddb61 commit 92346a0
Showing 1 changed file with 31 additions and 19 deletions.
50 changes: 31 additions & 19 deletions piptools/scripts/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ class PipCommand(pip.basecommand.Command):
help="Annotate results, indicating where dependencies come from")
@click.option('-U', '--upgrade', is_flag=True, default=False,
help='Try to upgrade all dependencies to their latest versions')
@click.option('-P', '--upgrade-package', nargs=1, multiple=True, help="Specify particular packages to upgrade.")
@click.option('-P', '--upgrade-package', nargs=1, multiple=True,
help="Specify particular packages to upgrade.")
@click.option('-o', '--output-file', nargs=1, type=str, default=None,
help=('Output file name. Required if more than one input file is given. '
'Will be derived from input file otherwise.'))
Expand All @@ -72,21 +73,25 @@ def cli(verbose, dry_run, pre, rebuild, find_links, index_url, extra_index_url,
if len(src_files) == 1 and src_files[0] == '-':
if not output_file:
raise click.BadParameter('--output-file is required if input is from stdin')
if upgrade_package:
raise click.BadParameter('upgrade_package does not support input from stdin.')

if len(src_files) > 1 and not output_file:
raise click.BadParameter('--output-file is required if two or more input files are given.')

if upgrade and upgrade_package:
raise click.BadParameter('Only one of --upgrade or --upgrade_package can be provided as an argument.')

if output_file:
dst_file = output_file
else:
base_name, _, _ = src_files[0].rpartition('.')
dst_file = base_name + '.txt'

if upgrade and upgrade_package:
raise click.BadParameter('Only one of --upgrade or --upgrade-package can be provided as an argument.')

# Process the arguments to upgrade-package into name and version pairs.
upgrade_package_reqs = [
package.split('==') if '==' in package else (package, None)
for package in upgrade_package
]

###
# Setup
###
Expand Down Expand Up @@ -122,22 +127,24 @@ def cli(verbose, dry_run, pre, rebuild, find_links, index_url, extra_index_url,
pip_options, _ = pip_command.parse_args(pip_args)

session = pip_command._build_session(pip_options)

repository = live_repository = PyPIRepository(pip_options, session)
local_repository = None
use_local_repo = not upgrade and os.path.exists(dst_file)

# Proxy with a LocalRequirementsRepository if --upgrade is not specified
# (= default invocation)
if not upgrade and os.path.exists(dst_file):
if use_local_repo:
existing_pins = dict()
ireqs = parse_requirements(dst_file, finder=repository.finder, session=repository.session, options=pip_options)
for ireq in ireqs:
if is_pinned_requirement(ireq):
existing_pins[name_from_req(ireq.req).lower()] = ireq

if upgrade_package:
for package in upgrade_package:
if package not in existing_pins:
log.error("Asked to upgrade %s but it's not already pinned. Quitting..." % package)
sys.exit(2)
for package, _ in upgrade_package_reqs:
if package not in existing_pins:
log.error("Asked to upgrade %s but it's not present in existing requirements. Quitting..." % package)
sys.exit(2)

repository = local_repository = LocalRequirementsRepository(existing_pins, repository)

Expand All @@ -156,7 +163,7 @@ def cli(verbose, dry_run, pre, rebuild, find_links, index_url, extra_index_url,
###

constraints = []
repository = live_repository if upgrade else local_repository
repository = local_repository if use_local_repo else live_repository
for src_file in src_files:
if src_file == '-':
# pip requires filenames and not files. Since we want to support
Expand All @@ -177,8 +184,10 @@ def cli(verbose, dry_run, pre, rebuild, find_links, index_url, extra_index_url,
upgraded_requirements = {}

# pip requires filenames, not files.
with tempfile.NamedTemporaryFile() as tmpfile:
tmpfile.write(str.encode('\n'.join(upgrade_package)))
with tempfile.NamedTemporaryFile(mode='wt') as tmpfile:
for package, version in upgrade_package_reqs:
line = '{}\n'.format(package) if not version else '{}=={}\n'.format(package, version)
tmpfile.write(line)
tmpfile.flush()

upgrade_candidates = list(
Expand All @@ -191,16 +200,19 @@ def cli(verbose, dry_run, pre, rebuild, find_links, index_url, extra_index_url,
.format(candidate.req.key))
sys.exit(2)

# we only want to process upgrades if the requirement is not pinned in the source file
constraint_candidate = existing_constraints[candidate.req.key]
if not constraint_candidate.req.specifier:
upgraded_requirements[candidate.req.key] = constraint_candidate
if constraint_candidate.req.specifier:
log.error("Asked to upgrade {} but it's pinned to a version in the source file. Quitting..."
.format(candidate.req.key))
sys.exit(2)
else:
upgraded_requirements[candidate.req.key] = candidate

existing_requirements.update(upgraded_requirements)
constraints = existing_requirements.values()

try:
repository = live_repository if upgrade or upgrade_package else local_repository
repository = local_repository if use_local_repo and not upgrade_package else live_repository
resolver = Resolver(constraints, repository, prereleases=pre, clear_caches=rebuild)
results = resolver.resolve()
except PipToolsError as e:
Expand Down

0 comments on commit 92346a0

Please sign in to comment.