Skip to content

Commit

Permalink
[#11246] MapView overrides transformations to return MapView when pos…
Browse files Browse the repository at this point in the history
…sible.
  • Loading branch information
joshlemer committed Feb 1, 2019
1 parent 439ed61 commit 09d43ff
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 9 deletions.
80 changes: 73 additions & 7 deletions src/library/scala/collection/MapView.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
package scala.collection


import scala.collection.MapView.SomeMapOps
import scala.collection.immutable.Map.Map1
import scala.collection.mutable.Builder

trait MapView[K, +V]
Expand All @@ -21,6 +23,12 @@ trait MapView[K, +V]

override def view: MapView[K, V] = this

def concat[V1 >: V](that: SomeMapOps[K, V1]): MapView[K, V1] = new MapView.Concat(this, that)

def ++[V1 >: V](that: SomeMapOps[K, V1]): MapView[K, V1] = concat(that)

override def +[V1 >: V](kv: (K, V1)): MapView[K, V1] = concat(new Map1(kv._1, kv._2))

/** Filters this map by retaining only keys satisfying a predicate.
* @param p the predicate used to test keys
* @return an immutable map consisting only of those key value pairs of this map where the key satisfies
Expand All @@ -35,18 +43,38 @@ trait MapView[K, +V]
*/
override def mapValues[W](f: V => W): MapView[K, W] = new MapView.MapValues(this, f)

def mapFactory: MapFactory[({ type l[X, Y] = View[(X, Y)] })#l] = new MapView.MapViewMapFactory[K, V]
override def filter(pred: ((K, V)) => Boolean): MapView[K, V] = new MapView.Filter(this, false, pred)

override def filterNot(pred: ((K, V)) => Boolean): MapView[K, V] = new MapView.Filter(this, true, pred)

def empty: View[(K, V)] = View.Empty
override def partition(p: ((K, V)) => Boolean): (MapView[K, V], MapView[K, V]) = (filter(p), filterNot(p))

def mapFactory: MapViewFactory = MapView

def empty: MapView[K, V] = mapFactory.empty
}

object MapView {
object MapView extends MapViewFactory {

/** An `IterableOps` whose collection type and collection type constructor are unknown */
type SomeIterableConstr[X, Y] = IterableOps[_, AnyConstr, _]
/** A `MapOps` whose collection type and collection type constructor are (mostly) unknown */
type SomeMapOps[K, +V] = MapOps[K, V, SomeIterableConstr, _]

@SerialVersionUID(3L)
private val EmptyMapView: MapView[Any, Nothing] = new AbstractMapView[Any, Nothing] {
override def get(key: Any): Option[Nothing] = None
override def iterator: Iterator[Nothing] = Iterator.empty[Nothing]
override def knownSize: Int = 0
override def isEmpty: Boolean = true
override def concat[V1 >: Nothing](that: SomeMapOps[Any, V1]): MapView[Any, V1] = mapFactory.from(that)
override def filterKeys(p: Any => Boolean): MapView[Any, Nothing] = this
override def mapValues[W](f: Nothing => W): MapView[Any, Nothing] = this
override def filter(pred: ((Any, Nothing)) => Boolean): MapView[Any, Nothing] = this
override def filterNot(pred: ((Any, Nothing)) => Boolean): MapView[Any, Nothing] = this
override def partition(p: ((Any, Nothing)) => Boolean): (MapView[Any, Nothing], MapView[Any, Nothing]) = (this, this)
}

@SerialVersionUID(3L)
class Id[K, +V](underlying: SomeMapOps[K, V]) extends AbstractMapView[K, V] {
def get(key: K): Option[V] = underlying.get(key)
Expand All @@ -55,6 +83,15 @@ object MapView {
override def isEmpty: Boolean = underlying.isEmpty
}

@SerialVersionUID(3L)
class Concat[K, +V](left: SomeMapOps[K, V], right: SomeMapOps[K, V]) extends AbstractMapView[K, V] {
def get(key: K): Option[V] = right.get(key) match {
case s @ Some(_) => s
case _ => left.get(key)
}
def iterator: Iterator[(K, V)] = left.iterator.filter { case (k, _) => !right.contains(k) }.concat(right.iterator)
}

@SerialVersionUID(3L)
class MapValues[K, +V, +W](underlying: SomeMapOps[K, V], f: V => W) extends AbstractMapView[K, W] {
def iterator: Iterator[(K, W)] = underlying.iterator.map(kv => (kv._1, f(kv._2)))
Expand All @@ -72,13 +109,42 @@ object MapView {
}

@SerialVersionUID(3L)
private class MapViewMapFactory[K, V] extends MapFactory[({ type l[X, Y] = View[(X, Y)] })#l] {
def newBuilder[X, Y]: Builder[(X, Y), View[(X, Y)]] = View.newBuilder[(X, Y)]
def empty[X, Y]: View[(X, Y)] = View.empty
def from[X, Y](it: IterableOnce[(X, Y)]): View[(X, Y)] = View.from(it)
class Filter[K, +V](underlying: SomeMapOps[K, V], isFlipped: Boolean, p: ((K, V)) => Boolean) extends AbstractMapView[K, V] {
def iterator: Iterator[(K, V)] = underlying.iterator.filterImpl(p, isFlipped)
def get(key: K): Option[V] = underlying.get(key) match {
case s @ Some(v) if p((key, v)) != isFlipped => s
case _ => None
}
override def knownSize: Int = if (underlying.knownSize == 0) 0 else super.knownSize
override def isEmpty: Boolean = iterator.isEmpty
}

override def newBuilder[X, Y]: Builder[(X, Y), MapView[X, Y]] = mutable.HashMap.newBuilder[X, Y].mapResult(_.view)

override def empty[K, V]: MapView[K, V] = EmptyMapView.asInstanceOf[MapView[K, V]]

override def from[K, V](it: IterableOnce[(K, V)]): View[(K, V)] = View.from(it)

override def from[K, V](it: SomeMapOps[K, V]): MapView[K, V] = it match {
case mv: MapView[K, V] => mv
case other => new MapView.Id(other)
}

override def apply[K, V](elems: (K, V)*): MapView[K, V] = from(elems.toMap)
}

trait MapViewFactory extends collection.MapFactory[({ type l[X, Y] = View[(X, Y)]})#l] {

def newBuilder[X, Y]: Builder[(X, Y), MapView[X, Y]]

def empty[X, Y]: MapView[X, Y]

def from[K, V](it: SomeMapOps[K, V]): MapView[K, V]

override def apply[K, V](elems: (K, V)*): MapView[K, V] = from(elems.toMap)
}

/** Explicit instantiation of the `MapView` trait to reduce class file size in subclasses. */
@SerialVersionUID(3L)
abstract class AbstractMapView[K, +V] extends AbstractView[(K, V)] with MapView[K, V]

2 changes: 0 additions & 2 deletions src/library/scala/collection/immutable/Map.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ package scala
package collection
package immutable

import java.io.{ObjectInputStream, ObjectOutputStream}

import scala.annotation.unchecked.uncheckedVariance
import scala.collection.immutable.Map.Map4
import scala.collection.mutable.{Builder, ReusableBuilder}
Expand Down
48 changes: 48 additions & 0 deletions test/scalacheck/scala/collection/MapViewProperties.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package scala.collection

import org.scalacheck.Arbitrary._
import org.scalacheck.Prop.forAll
import org.scalacheck._

object MapViewProperties extends Properties("MapView") {

type K = Int
type V = Int
type T = (K, V)

val x = MapView.from(List(1 -> ""))

property("filter behaves like Map.filter") = forAll { (m: Map[K, V], p: ((K, V)) => Boolean, isFlipped: Boolean) =>
if (isFlipped)
m.filterNot(p) == m.view.filterNot(p).toMap
else
m.filter(p) == m.view.filter(p).toMap
}

property("concat behaves like Map.concat") = forAll { (m0: Map[K, V], m1: Map[K, V]) =>
val strictStrict = m0 concat m1
val strictView = m0 concat m1.view
val viewStrict = (m0.view concat m1).toMap
val viewView = (m0.view concat m1.view).toMap

strictStrict == strictView &&
strictView == viewStrict &&
viewStrict == viewView
}
property("++ behaves like Map.++") = forAll { (m0: Map[K, V], m1: Map[K, V]) =>
val strictStrict = m0 ++ m1
val strictView = m0 ++ m1.view
val viewStrict = (m0.view ++ m1).toMap
val viewView = (m0.view ++ m1.view).toMap

strictStrict == strictView &&
strictView == viewStrict &&
viewStrict == viewView
}
property("partition behaves like Map.partition") = forAll { (m: Map[K, V], p: ((K, V)) => Boolean) =>
val strict = m.partition(p)
val (viewA, viewB) = m.view.partition(p)
strict == (viewA.toMap, viewB.toMap)
}

}

0 comments on commit 09d43ff

Please sign in to comment.