Skip to content

Commit

Permalink
Support scala_artifact (pantsbuild#19128)
Browse files Browse the repository at this point in the history
  • Loading branch information
alonsodomin authored May 26, 2023
1 parent 840608b commit b455834
Show file tree
Hide file tree
Showing 8 changed files with 589 additions and 7 deletions.
13 changes: 13 additions & 0 deletions docs/markdown/Java and Scala/jvm-overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,19 @@ jvm_artifact(
)
```

If your third party dependency is a Scala library, you should use the `scala_artifact` target instead like follows:

```python BUILD
scala_artifact(
group="org.typelevel",
artifact="cats-core",
version="2.9.0",
packages=["cats.**"],
)
```

Pants will use the right artifact for the Scala version corresponding for the resolve specified (or the default one).

Pants requires use of a lockfile for thirdparty dependencies. After adding or editing `jvm_artifact` targets, you will need to update affected lockfiles by running `pants generate-lockfiles`. The default lockfile is located at `3rdparty/jvm/default.lock`, but it can be relocated (as well as additional resolves declared) via the [`[jvm].resolves` option](doc:reference-jvm#section-resolves).

> 📘 Thirdparty symbols and the `packages` argument
Expand Down
2 changes: 1 addition & 1 deletion pants.toml
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ args = ["--external-sources"]
args = ["-i 2", "-ci", "-sr"]

[pytest]
args = ["--no-header"]
args = ["--no-header", "-vv"]
execution_slot_var = "TEST_EXECUTION_SLOT"
install_from_resolve = "pytest"
requirements = ["//3rdparty/python:pytest"]
Expand Down
5 changes: 4 additions & 1 deletion src/python/pants/backend/experimental/scala/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pants.backend.scala.goals import check, repl, tailor
from pants.backend.scala.resolve.lockfile import rules as scala_lockfile_rules
from pants.backend.scala.target_types import (
ScalaArtifactTarget,
ScalacPluginTarget,
ScalaJunitTestsGeneratorTarget,
ScalaJunitTestTarget,
Expand All @@ -15,6 +16,7 @@
ScalatestTestsGeneratorTarget,
ScalatestTestTarget,
)
from pants.backend.scala.target_types import build_file_aliases as scala_build_file_aliases
from pants.backend.scala.target_types import rules as target_types_rules
from pants.backend.scala.test import scalatest
from pants.core.util_rules.wrap_source import wrap_source_rule_and_target
Expand All @@ -32,6 +34,7 @@ def target_types():
ScalacPluginTarget,
ScalatestTestTarget,
ScalatestTestsGeneratorTarget,
ScalaArtifactTarget,
*jvm_common.target_types(),
*wrap_scala.target_types,
]
Expand All @@ -54,4 +57,4 @@ def rules():


def build_file_aliases():
return jvm_common.build_file_aliases()
return jvm_common.build_file_aliases().merge(scala_build_file_aliases())
4 changes: 4 additions & 0 deletions src/python/pants/backend/scala/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,7 @@
# Licensed under the Apache License, Version 2.0 (see LICENSE).

python_sources()

python_tests(
name="tests",
)
203 changes: 200 additions & 3 deletions src/python/pants/backend/scala/target_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,21 @@
from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
from typing import ClassVar

from pants.backend.scala.subsystems.scala import ScalaSubsystem
from pants.backend.scala.subsystems.scala_infer import ScalaInferSubsystem
from pants.build_graph.build_file_aliases import BuildFileAliases
from pants.core.goals.test import TestExtraEnvVarsField, TestTimeoutField
from pants.engine.rules import collect_rules, rule
from pants.engine.target import (
COMMON_TARGET_FIELDS,
AsyncFieldMixin,
Dependencies,
FieldSet,
GeneratedTargets,
GenerateTargetsRequest,
MultipleSourcesField,
OverridesField,
SingleSourceField,
Expand All @@ -22,22 +28,35 @@
TargetFilesGenerator,
TargetFilesGeneratorSettings,
TargetFilesGeneratorSettingsRequest,
TargetGenerator,
generate_file_based_overrides_field_help_message,
generate_multiple_sources_field_help_message,
)
from pants.engine.unions import UnionRule
from pants.engine.unions import UnionMembership, UnionRule
from pants.jvm import target_types as jvm_target_types
from pants.jvm.subsystems import JvmSubsystem
from pants.jvm.target_types import (
JunitTestExtraEnvVarsField,
JunitTestSourceField,
JunitTestTimeoutField,
JvmArtifactArtifactField,
JvmArtifactExclusion,
JvmArtifactExclusionsField,
JvmArtifactGroupField,
JvmArtifactJarSourceField,
JvmArtifactPackagesField,
JvmArtifactResolveField,
JvmArtifactTarget,
JvmArtifactUrlField,
JvmArtifactVersionField,
JvmJdkField,
JvmMainClassNameField,
JvmProvidesTypesField,
JvmResolveField,
JvmRunnableSourceFieldSet,
_jvm_artifact_exclusions_field_help,
)
from pants.util.strutil import help_text
from pants.util.strutil import help_text, softwrap


class ScalaSettingsRequest(TargetFilesGeneratorSettingsRequest):
Expand Down Expand Up @@ -324,7 +343,7 @@ class ScalacPluginArtifactField(StringField, AsyncFieldMixin):
alias = "artifact"
required = True
value: str
help = "The address of a `jvm_artifact` that defines a plugin for `scalac`."
help = "The address of either a `jvm_artifact` or a `scala_artifact` that defines a plugin for `scalac`."


class ScalacPluginNameField(StringField):
Expand Down Expand Up @@ -359,10 +378,188 @@ class ScalacPluginTarget(Target):
)


# -----------------------------------------------------------------------------------------------
# `scala_artifact` target
# -----------------------------------------------------------------------------------------------


class ScalaCrossVersion(Enum):
PARTIAL = "partial"
FULL = "full"


class ScalaArtifactCrossversionField(StringField):
alias = "crossversion"
default = ScalaCrossVersion.PARTIAL.value
help = help_text(
"""
Whether to use the full Scala version or the partial one to determine the artifact name suffix.
Default is `partial`.
"""
)
valid_choices = ScalaCrossVersion


@dataclass(frozen=True)
class ScalaArtifactExclusion(JvmArtifactExclusion):
alias = "scala_exclude"
help = help_text(
"""
Exclude the given `artifact` and `group`, or all artifacts from the given `group`.
You can also use the `crossversion` field to help resolve the final artifact name.
"""
)

crossversion: str = ScalaCrossVersion.PARTIAL.value

def validate(self) -> set[str]:
errors = super().validate()
valid_crossversions = [x.value for x in ScalaCrossVersion]
if self.crossversion not in valid_crossversions:
errors.add(
softwrap(
f"""
Invalid `crossversion` value: {self.crossversion}. Valid values are:
{', '.join(valid_crossversions)}
"""
)
)
return errors


class ScalaArtifactExclusionsField(JvmArtifactExclusionsField):
help = _jvm_artifact_exclusions_field_help(
lambda: ScalaArtifactExclusionsField.supported_rule_types
)
supported_rule_types: ClassVar[tuple[type[JvmArtifactExclusion], ...]] = (
JvmArtifactExclusion,
ScalaArtifactExclusion,
)


@dataclass(frozen=True)
class ScalaArtifactFieldSet(FieldSet):
group: JvmArtifactGroupField
artifact: JvmArtifactArtifactField
version: JvmArtifactVersionField
packages: JvmArtifactPackagesField
exclusions: ScalaArtifactExclusionsField
crossversion: ScalaArtifactCrossversionField

required_fields = (
JvmArtifactGroupField,
JvmArtifactArtifactField,
JvmArtifactVersionField,
JvmArtifactPackagesField,
ScalaArtifactCrossversionField,
)


class ScalaArtifactTarget(TargetGenerator):
alias = "scala_artifact"
help = help_text(
"""
A third-party Scala artifact, as identified by its Maven-compatible coordinate.
That is, an artifact identified by its `group`, `artifact`, and `version` components.
Each artifact is associated with one or more resolves (a logical name you give to a
lockfile). For this artifact to be used by your first-party code, it must be
associated with the resolve(s) used by that code. See the `resolve` field.
Being a Scala artifact, the final artifact name will be inferred using the Scala version
configured for the given resolve.
"""
)
core_fields = (
*COMMON_TARGET_FIELDS,
*ScalaArtifactFieldSet.required_fields,
ScalaArtifactExclusionsField,
JvmArtifactUrlField,
JvmArtifactJarSourceField,
JvmMainClassNameField,
)
copied_fields = (
*COMMON_TARGET_FIELDS,
JvmArtifactGroupField,
JvmArtifactVersionField,
JvmArtifactPackagesField,
JvmArtifactUrlField,
JvmArtifactJarSourceField,
JvmMainClassNameField,
)
moved_fields = (
JvmArtifactResolveField,
JvmJdkField,
)


class GenerateJvmArtifactForScalaTargets(GenerateTargetsRequest):
generate_from = ScalaArtifactTarget


@rule
async def generate_jvm_artifact_targets(
request: GenerateJvmArtifactForScalaTargets,
jvm: JvmSubsystem,
scala: ScalaSubsystem,
union_membership: UnionMembership,
) -> GeneratedTargets:
field_set = ScalaArtifactFieldSet.create(request.generator)
resolve_name = request.template.get(JvmArtifactResolveField.alias) or jvm.default_resolve
scala_version = scala.version_for_resolve(resolve_name)
scala_version_parts = scala_version.split(".")

def scala_suffix(crossversion: ScalaCrossVersion) -> str:
if crossversion == ScalaCrossVersion.FULL:
return scala_version
elif int(scala_version_parts[0]) >= 3:
return scala_version_parts[0]

return f"{scala_version_parts[0]}.{scala_version_parts[1]}"

exclusions_field = {}
if field_set.exclusions.value:
exclusions = []
for exclusion in field_set.exclusions.value:
if not isinstance(exclusion, ScalaArtifactExclusion):
exclusions.append(exclusion)
else:
excluded_artifact_name = None
if exclusion.artifact:
crossversion = ScalaCrossVersion(exclusion.crossversion)
excluded_artifact_name = f"{exclusion.artifact}_{scala_suffix(crossversion)}"
exclusions.append(
JvmArtifactExclusion(group=exclusion.group, artifact=excluded_artifact_name)
)
exclusions_field[JvmArtifactExclusionsField.alias] = exclusions

crossversion = ScalaCrossVersion(field_set.crossversion.value)
artifact_name = f"{field_set.artifact.value}_{scala_suffix(crossversion)}"
jvm_artifact_target = JvmArtifactTarget(
{
**request.template,
JvmArtifactArtifactField.alias: artifact_name,
**exclusions_field,
},
request.generator.address.create_generated(artifact_name),
union_membership,
residence_dir=request.generator.address.spec_path,
)

return GeneratedTargets(request.generator, (jvm_artifact_target,))


def rules():
return (
*collect_rules(),
*jvm_target_types.rules(),
*ScalaFieldSet.jvm_rules(),
UnionRule(TargetFilesGeneratorSettingsRequest, ScalaSettingsRequest),
UnionRule(GenerateTargetsRequest, GenerateJvmArtifactForScalaTargets),
)


def build_file_aliases():
return BuildFileAliases(objects={ScalaArtifactExclusion.alias: ScalaArtifactExclusion})
Loading

0 comments on commit b455834

Please sign in to comment.