Skip to content

Commit

Permalink
mapAnnotations takes new va type. (hail-is#2188)
Browse files Browse the repository at this point in the history
The previous idiom was mapAnnotations(...).copy(vaSignature =
newVASignature), but this results (temporarily) in a VDS with an
incorrect va type, which causes problems for downstream changes
and assertions like typecheck.
  • Loading branch information
cseed authored and jigold committed Sep 5, 2017
1 parent 9e4a5bb commit fb436a0
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 73 deletions.
4 changes: 2 additions & 2 deletions src/main/scala/is/hail/methods/ImputeSexPlink.scala
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ object ImputeSexPlink {
else
v.contig == "X" || v.contig == "23" || v.contig == "25"
}
.mapAnnotations { case (v, va, gs) =>
.mapAnnotations(TFloat64, { case (v, va, gs) =>
query.map(_.apply(va))
.getOrElse {
var nAlt = 0
Expand All @@ -65,7 +65,7 @@ object ImputeSexPlink {
else
null
}
}
})
.filterVariants { case (v, va, _) => Option(va).exists(_.asInstanceOf[Double] > mafThreshold) }
.aggregateBySampleWithAll(new InbreedingCombiner)({ case (ibc, _, va, _, _, gt) =>
ibc.merge(gt, va.asInstanceOf[Double])
Expand Down
9 changes: 4 additions & 5 deletions src/main/scala/is/hail/methods/LinearMixedRegression.scala
Original file line number Diff line number Diff line change
Expand Up @@ -162,14 +162,14 @@ object LinearMixedRegression {

val scalerLMMBc = sc.broadcast(scalerLMM)

vds2.mapAnnotations { case (v, va, gs) =>
vds2.mapAnnotations(newVAS, { case (v, va, gs) =>
val x: Vector[Double] =
if (!useDosages) {
val x0 = RegressionUtils.hardCalls(gs, n, sampleMaskBc.value)
if (x0.used <= sparsityThreshold * n) x0 else x0.toDenseVector
} else
RegressionUtils.dosages(gs, completeSampleIndexBc.value)

// TODO constant checking to be removed in 0.2
val nonConstant = useDosages || !RegressionUtils.constantVector(x)

Expand All @@ -178,9 +178,8 @@ object LinearMixedRegression {
val newAnnotation = inserter(va, lmmregAnnot)
assert(newVAS.typeCheck(newAnnotation))
newAnnotation
}.copy(vaSignature = newVAS)
}
else
})
} else
vds2
}

Expand Down
4 changes: 2 additions & 2 deletions src/main/scala/is/hail/variant/VariantDataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class VariantDatasetFunctions(private val vds: VariantDataset) extends AnyVal {

val aggregateOption = Aggregators.buildVariantAggregations(vds, ec)

vds.mapAnnotations { case (v, va, gs) =>
vds.mapAnnotations(finalType, { case (v, va, gs) =>

val annotations = SplitMulti.split(v, va, gs,
propagateGQ = propagateGQ,
Expand All @@ -103,7 +103,7 @@ class VariantDatasetFunctions(private val vds: VariantDataset) extends AnyVal {
inserter(va, annotations.map(_ (i)).toArray[Any]: IndexedSeq[Any])
}

}.copy(vaSignature = finalType)
})
}

def concordance(other: VariantDataset): (IndexedSeq[IndexedSeq[Long]], KeyTable, KeyTable) = {
Expand Down
36 changes: 5 additions & 31 deletions src/main/scala/is/hail/variant/VariantSampleMatrix.scala
Original file line number Diff line number Diff line change
Expand Up @@ -633,15 +633,15 @@ class VariantSampleMatrix[RPK, RK, T >: Null](val hc: HailContext, val metadata:

val aggregateOption = Aggregators.buildVariantAggregations(this, ec)

mapAnnotations { case (v, va, gs) =>
mapAnnotations(finalType, { case (v, va, gs) =>
ec.setAll(localGlobalAnnotation, v, va)

aggregateOption.foreach(f => f(v, va, gs))
f().zip(inserters)
.foldLeft(va) { case (va, (v, inserter)) =>
inserter(va, v)
}
}.copy(vaSignature = finalType)
})
}

def annotateVariantsTable(kt: KeyTable, vdsKey: java.util.ArrayList[String],
Expand Down Expand Up @@ -1378,35 +1378,9 @@ class VariantSampleMatrix[RPK, RK, T >: Null](val hc: HailContext, val metadata:
}
}

/**
 * Rewrites the annotation of every row by applying `f` to the row key, the current annotation
 * and the row's `gs` values. The `gs` values are passed through unchanged.
 *
 * Only `rdd` is replaced in the `copy` call; `vaSignature` is not touched here, so a caller
 * whose `f` changes the annotation type has to update the signature separately.
 *
 * @param f function from (row key, current annotation, gs) to the new annotation
 * @return a copy of this matrix whose rdd carries the annotations returned by `f`
 */
def mapAnnotations(f: (RK, Annotation, Iterable[T]) => Annotation): VariantSampleMatrix[RPK, RK, T] =
copy(rdd = rdd.mapValuesWithKey { case (v, (va, gs)) => (f(v, va, gs), gs) })

/**
 * Replaces each row's annotation with the result of an aggregation over that row's `gs` values,
 * and sets `vaSignature` to `newVAS` in the same `copy` call.
 *
 * For every row, a fresh copy of `zeroValue` is folded with `seqOp` over the row's `gs` values
 * zipped with the broadcast sample ids and sample annotations; `mapOp` then combines the row's
 * current annotation with the fold result to produce the new annotation. The `gs` values are
 * passed through unchanged.
 *
 * @param zeroValue initial accumulator; it is serialized once and deserialized again per row
 * @param newVAS    signature recorded as `vaSignature` on the returned matrix
 * @param seqOp     fold step, called as `seqOp(acc, v, va, s, sa, g)`
 * @param combOp    NOTE(review): not referenced anywhere in this body
 * @param mapOp     builds the new annotation from (current annotation, fold result)
 * @param uct       ClassTag for `U`, passed implicitly
 * @return a copy of this matrix with `vaSignature = newVAS` and the remapped rdd
 */
def mapAnnotationsWithAggregate[U](zeroValue: U, newVAS: Type)(
seqOp: (U, RK, Annotation, Annotation, Annotation, T) => U,
combOp: (U, U) => U,
mapOp: (Annotation, U) => Annotation)
(implicit uct: ClassTag[U]): VariantSampleMatrix[RPK, RK, T] = {

// Serialize the zero value to a byte array so that we can apply a new clone of it on each key
val zeroBuffer = SparkEnv.get.serializer.newInstance().serialize(zeroValue)
val zeroArray = new Array[Byte](zeroBuffer.limit)
zeroBuffer.get(zeroArray)

// Local aliases for the broadcast handles used inside the closure below.
// NOTE(review): presumably so the closure does not capture `this` — confirm.
val localSampleIdsBc = sampleIdsBc
val localSampleAnnotationsBc = sampleAnnotationsBc

copy(vaSignature = newVAS,
rdd = rdd.mapValuesWithKey { case (v, (va, gs)) =>
// Rebuild the accumulator start value from the serialized bytes for this row;
// this local `zeroValue` shadows the method parameter of the same name.
val serializer = SparkEnv.get.serializer.newInstance()
val zeroValue = serializer.deserialize[U](ByteBuffer.wrap(zeroArray))

(mapOp(va, gs.iterator
.zip(localSampleIdsBc.value.iterator
.zip(localSampleAnnotationsBc.value.iterator)).foldLeft(zeroValue) { case (acc, (g, (s, sa))) =>
seqOp(acc, v, va, s, sa, g)
}), gs)
})
}
/**
 * Rewrites the annotation of every row by applying `f` to the row key, the current annotation
 * and the row's `gs` values, and sets `vaSignature` to `newVASignature` in the same `copy`
 * call. The `gs` values are passed through unchanged.
 *
 * NOTE(review): nothing in this method checks that the values returned by `f` conform to
 * `newVASignature`; presumably that is the caller's responsibility — confirm.
 *
 * @param newVASignature signature recorded as `vaSignature` on the returned matrix
 * @param f              function from (row key, current annotation, gs) to the new annotation
 * @return a copy of this matrix with the new signature and the remapped rdd
 */
def mapAnnotations(newVASignature: Type, f: (RK, Annotation, Iterable[T]) => Annotation): VariantSampleMatrix[RPK, RK, T] =
copy(vaSignature = newVASignature,
rdd = rdd.mapValuesWithKey { case (v, (va, gs)) => (f(v, va, gs), gs) })

def mapPartitionsWithAll[U](f: Iterator[(RK, Annotation, Annotation, Annotation, T)] => Iterator[U])
(implicit uct: ClassTag[U]): RDD[U] = {
Expand Down
39 changes: 13 additions & 26 deletions src/test/scala/is/hail/annotations/AnnotationsSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -152,16 +152,14 @@ class AnnotationsSuite extends SparkSuite {

// clear everything
val (emptyS, d1) = vds.deleteVA()
vds = vds.mapAnnotations((v, va, gs) => d1(va))
.copy(vaSignature = emptyS)
vds = vds.mapAnnotations(emptyS, (v, va, gs) => d1(va))
assert(emptyS == TStruct.empty)

// add to the first layer
val toAdd = 5
val toAddSig = TInt32
val (s1, i1) = vds.vaSignature.insert(toAddSig, "I1")
vds = vds.mapAnnotations((v, va, gs) => i1(va, toAdd))
.copy(vaSignature = s1)
vds = vds.mapAnnotations(s1, (v, va, gs) => i1(va, toAdd))
assert(vds.vaSignature.schema ==
StructType(Array(StructField("I1", IntegerType))))

Expand All @@ -174,8 +172,7 @@ class AnnotationsSuite extends SparkSuite {
val toAdd2 = "test"
val toAdd2Sig = TString
val (s2, i2) = vds.vaSignature.insert(toAdd2Sig, "S1")
vds = vds.mapAnnotations((v, va, gs) => i2(va, toAdd2))
.copy(vaSignature = s2)
vds = vds.mapAnnotations(s2, (v, va, gs) => i2(va, toAdd2))
assert(vds.vaSignature.schema ==
StructType(Array(
StructField("I1", IntegerType),
Expand All @@ -191,8 +188,7 @@ class AnnotationsSuite extends SparkSuite {
val toAdd3Sig = TStruct("I2" -> TInt32,
"I3" -> TInt32)
val (s3, i3) = vds.vaSignature.insert(toAdd3Sig, "I1")
vds = vds.mapAnnotations((v, va, gs) => i3(va, toAdd3))
.copy(vaSignature = s3)
vds = vds.mapAnnotations(s3, (v, va, gs) => i3(va, toAdd3))
assert(vds.vaSignature.schema ==
StructType(Array(
StructField("I1", StructType(Array(
Expand All @@ -216,8 +212,7 @@ class AnnotationsSuite extends SparkSuite {
val toAdd4 = "dummy"
val toAdd4Sig = TString
val (s4, i4) = vds.insertVA(toAdd4Sig, "a", "b", "c", "d", "e")
vds = vds.mapAnnotations((v, va, gs) => i4(va, toAdd4))
.copy(vaSignature = s4)
vds = vds.mapAnnotations(s4, (v, va, gs) => i4(va, toAdd4))
assert(vds.vaSignature.schema ==
StructType(Array(
StructField("I1", toAdd3Sig.schema),
Expand All @@ -237,8 +232,7 @@ class AnnotationsSuite extends SparkSuite {
val toAdd5 = "dummy2"
val toAdd5Sig = TString
val (s5, i5) = vds.insertVA(toAdd5Sig, "a", "b", "c", "f")
vds = vds.mapAnnotations((v, va, gs) => i5(va, toAdd5))
.copy(vaSignature = s5)
vds = vds.mapAnnotations(s5, (v, va, gs) => i5(va, toAdd5))

assert(vds.vaSignature.schema ==
StructType(Array(
Expand All @@ -259,8 +253,7 @@ class AnnotationsSuite extends SparkSuite {
val toAdd6 = "dummy3"
val toAdd6Sig = TString
val (s6, i6) = vds.insertVA(toAdd6Sig, "a", "b", "c", "d")
vds = vds.mapAnnotations((v, va, gs) => i6(va, toAdd6))
.copy(vaSignature = s6)
vds = vds.mapAnnotations(s6, (v, va, gs) => i6(va, toAdd6))

assert(vds.vaSignature.schema ==
StructType(Array(
Expand All @@ -280,8 +273,7 @@ class AnnotationsSuite extends SparkSuite {
val toAdd7 = "dummy4"
val toAdd7Sig = TString
val (s7, i7) = vds.insertVA(toAdd7Sig, "a", "c")
vds = vds.mapAnnotations((v, va, gs) => i7(va, toAdd7))
.copy(vaSignature = s7)
vds = vds.mapAnnotations(s7, (v, va, gs) => i7(va, toAdd7))

assert(vds.vaSignature.schema ==
StructType(Array(
Expand All @@ -300,8 +292,7 @@ class AnnotationsSuite extends SparkSuite {

// delete a.b.c and ensure that b is deleted and a.c gets shifted over
val (s8, d2) = vds.deleteVA("a", "b", "c")
vds = vds.mapAnnotations((v, va, gs) => d2(va))
.copy(vaSignature = s8)
vds = vds.mapAnnotations(s8, (v, va, gs) => d2(va))
assert(vds.vaSignature.schema ==
StructType(Array(
StructField("I1", toAdd3Sig.schema),
Expand All @@ -315,8 +306,7 @@ class AnnotationsSuite extends SparkSuite {

// delete that part of the tree
val (s9, d3) = vds.deleteVA("a")
vds = vds.mapAnnotations((v, va, gs) => d3(va))
.copy(vaSignature = s9)
vds = vds.mapAnnotations(s9, (v, va, gs) => d3(va))

assert(vds.vaSignature.schema ==
StructType(Array(
Expand All @@ -329,8 +319,7 @@ class AnnotationsSuite extends SparkSuite {

// delete the first thing in the row and make sure things are shifted over correctly
val (s10, d4) = vds.deleteVA("I1")
vds = vds.mapAnnotations((v, va, gs) => d4(va))
.copy(vaSignature = s10)
vds = vds.mapAnnotations(s10, (v, va, gs) => d4(va))

assert(vds.vaSignature.schema ==
StructType(Array(
Expand All @@ -343,8 +332,7 @@ class AnnotationsSuite extends SparkSuite {
val toAdd8 = "dummy"
val toAdd8Sig = TString
val (s11, i8) = vds.insertVA(toAdd8Sig, List[String]())
vds = vds.mapAnnotations((v, va, gs) => i8(va, toAdd8))
.copy(vaSignature = s11)
vds = vds.mapAnnotations(s11, (v, va, gs) => i8(va, toAdd8))

assert(vds.vaSignature.schema == toAdd8Sig.schema)
assert(vds.variantsAndAnnotations.collect()
Expand Down Expand Up @@ -399,8 +387,7 @@ class AnnotationsSuite extends SparkSuite {

val f = tmpDir.createTempFile("testwrite", extension = ".vds")
val (newS, ins) = vds.insertVA(TInt32, "ThisName(won'twork)=====")
vds = vds.mapAnnotations((v, va, gs) => ins(va, 5))
.copy(vaSignature = newS)
vds = vds.mapAnnotations(newS, (v, va, gs) => ins(va, 5))
vds.write(f)

assert(hc.readVDS(f).same(vds))
Expand Down
11 changes: 6 additions & 5 deletions src/test/scala/is/hail/io/ExportVcfSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import org.testng.annotations.Test
import scala.io.Source
import scala.language.postfixOps

class ExportVcfSuite extends SparkSuite {
class ExportVCFSuite extends SparkSuite {

@Test def testSameAsOrigBGzip() {
val vcfFile = "src/test/resources/multipleChromosomes.vcf"
Expand All @@ -40,12 +40,13 @@ class ExportVcfSuite extends SparkSuite {

assert(vdsOrig.same(vdsNew))

val infoSize = vdsNew.vaSignature.getAsOption[TStruct]("info").get.size
val infoType = vdsNew.vaSignature.getAsOption[TStruct]("info").get
val infoSize = infoType.size
val toAdd = Annotation.fromSeq(Array.fill[Any](infoSize)(null))
val (_, inserter) = vdsNew.insertVA(null, "info")

val vdsNewMissingInfo = vdsNew.mapAnnotations((v, va, gs) => inserter(va, toAdd))
val (newVASignature, inserter) = vdsNew.insertVA(infoType, "info")

val vdsNewMissingInfo = vdsNew.mapAnnotations(newVASignature,
(v, va, gs) => inserter(va, toAdd))

vdsNewMissingInfo.exportVCF(outFile2)

Expand Down
3 changes: 1 addition & 2 deletions src/test/scala/is/hail/methods/FilterSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,7 @@ class FilterSuite extends SparkSuite {
var vds = hc.importVCF("src/test/resources/sample.vcf")
val (sigs, i) = vds.insertVA(TInt32, "weird name \t test")
vds = vds
.mapAnnotations((v, va, gs) => i(va, 1000))
.copy(vaSignature = sigs)
.mapAnnotations(sigs, (v, va, gs) => i(va, 1000))
assert(vds.filterVariantsExpr("va.`weird name \\t test` > 500").countVariants() == vds.countVariants)

TestUtils.interceptFatal("invalid escape character.*backtick identifier.*\\\\i")(
Expand Down

0 comments on commit fb436a0

Please sign in to comment.