Skip to content

Commit

Permalink
prioritize X-B3-Sampled over other tracing headers, fix forced sampli…
Browse files Browse the repository at this point in the history
…ng in Spray instrumentation, improve test coverage
  • Loading branch information
levkhomich committed May 1, 2016
1 parent e910d35 commit f8f844f
Show file tree
Hide file tree
Showing 5 changed files with 175 additions and 78 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,28 +37,32 @@ trait TracingSettings extends GlobalSettings with PlayControllerTracing {
protected def sample(request: RequestHeader): Unit = {
def headerLongValue(tag: String): Option[Long] =
Try(request.headers.get(tag).map(SpanMetadata.idFromString)).getOrElse(None)
def maybeForceSampling: Option[Boolean] = {
def spanId: Long =
headerLongValue(SpanId).getOrElse(Random.nextLong)

val maybeForceSampling =
request.headers.get(Sampled).map(_.toLowerCase) match {
case Some("0") | Some("false") =>
Some(false)
case Some("1") | Some("true") =>
Some(true)
case _ =>
request.headers.get(Flags).flatMap(flags =>
Try((java.lang.Long.parseLong(flags) & DebugFlag) == DebugFlag).toOption)
Try((java.lang.Long.parseLong(flags) & DebugFlag) == DebugFlag).toOption.filter(v => v))
}
}
def spanId: Long =
headerLongValue(SpanId).getOrElse(Random.nextLong)

headerLongValue(TraceId) match {
case Some(traceId) =>
val parentId = headerLongValue(ParentSpanId)
trace.sample(request.tracingId, spanId, parentId, traceId, serviceName,
request.spanName, maybeForceSampling.getOrElse(false))
maybeForceSampling match {
case Some(false) =>
// ignore
case _ =>
maybeForceSampling.foreach(forceSampling =>
trace.sample(request, serviceName, forceSampling))
val forceSampling = maybeForceSampling.getOrElse(false)
headerLongValue(TraceId) match {
case Some(traceId) =>
val parentId = headerLongValue(ParentSpanId)
trace.sample(request.tracingId, spanId, parentId, traceId, serviceName, request.spanName, forceSampling)
case _ =>
trace.sample(request, serviceName, forceSampling)
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,13 @@ class PlayTracingSpec extends PlaySpecification with TracingTestCommons with Moc
def disabledLocalSamplingApplication: FakeApplication = FakeApplication(
withRoutes = routes,
withGlobal = Some(new GlobalSettings with TracingSettings),
additionalConfiguration = configuration ++ Map(TracingExtension.AkkaTracingSampleRate -> 0)
additionalConfiguration = configuration ++ Map(TracingExtension.AkkaTracingSampleRate -> Int.MaxValue)
)

"Play tracing" should {
"sample requests" in new WithApplication(fakeApplication) {
val result = route(FakeRequest("GET", TestPath)).map(Await.result(_, defaultAwaitTimeout.duration))
val span = receiveSpan()
success
expectSpans(1)
}

"use play application name as the default end point name" in new WithApplication(fakeApplication) {
Expand Down Expand Up @@ -140,7 +139,7 @@ class PlayTracingSpec extends PlaySpecification with TracingTestCommons with Moc
checkAbsentBinaryAnnotation(span, "request.headers.Excluded")
}

"propagate tracing headers" in new WithApplication(fakeApplication) {
"support trace propagation from external service" in new WithApplication(fakeApplication) {
val traceId = Random.nextLong
val parentId = Random.nextLong

Expand All @@ -161,84 +160,112 @@ class PlayTracingSpec extends PlaySpecification with TracingTestCommons with Moc
checkAnnotation(span, TracingExtension.getStackTrace(npe))
}

val Sampled = Seq("1", "true")
val NotSampled = Seq("0", "false")
Seq("1", "true").foreach { value =>
s"honour upstream's X-B3-Sampled: $value header" in new WithApplication(disabledLocalSamplingApplication) {
val spanId = Random.nextLong
val result = route(FakeRequest("GET", TestPath + "?key=value",
FakeHeaders(Seq(
TracingHeaders.TraceId -> Seq(SpanMetadata.idToString(spanId)),
TracingHeaders.Sampled -> Seq(value)
)), AnyContentAsEmpty)).map(Await.result(_, defaultAwaitTimeout.duration))
expectSpans(1)
}
}

(Sampled ++ NotSampled).foreach { value =>
s"honour upstream's Sampled: $value header" in new WithApplication(fakeApplication) {
Seq("0", "false").foreach { value =>
s"honour upstream's X-B3-Sampled: $value header" in new WithApplication(fakeApplication) {
val spanId = Random.nextLong
val result = route(FakeRequest("GET", TestPath + "?key=value",
FakeHeaders(Seq(
TracingHeaders.TraceId -> Seq(SpanMetadata.idToString(spanId)),
TracingHeaders.Sampled -> Seq(value)
)), AnyContentAsEmpty)).map(Await.result(_, defaultAwaitTimeout.duration))
expectSpans(if (Sampled.contains(value)) 1 else 0)
expectSpans(0)
}
}

"allow forced sampling" in new WithApplication(disabledLocalSamplingApplication) {
Seq("1", "true").foreach { value =>
s"honour upstream's X-B3-Sampled: $value header if X-B3-TraceId is not specified" in new WithApplication(disabledLocalSamplingApplication) {
val spanId = Random.nextLong
val result = route(FakeRequest("GET", TestPath + "?key=value",
FakeHeaders(Seq(
TracingHeaders.Sampled -> Seq(value)
)), AnyContentAsEmpty)).map(Await.result(_, defaultAwaitTimeout.duration))
expectSpans(1)
}
}

"honour upstream's Debug flag" in new WithApplication(disabledLocalSamplingApplication) {
val result = route(FakeRequest("GET", TestPath,
FakeHeaders(Seq(
TracingHeaders.Flags -> Seq("1")
)), AnyContentAsEmpty)).map(Await.result(_, defaultAwaitTimeout.duration))

receiveSpans() must haveLength(1)
expectSpans(1)
}

"ignore broken Flags header" in new WithApplication(disabledLocalSamplingApplication) {
"user regular sampling if X-B3-Flags does not contain Debug flag" in new WithApplication(disabledLocalSamplingApplication) {
val result = route(FakeRequest("GET", TestPath,
FakeHeaders(Seq(
TracingHeaders.Flags -> Seq("broken")
TracingHeaders.Flags -> Seq("2")
)), AnyContentAsEmpty)).map(Await.result(_, defaultAwaitTimeout.duration))

expectSpans(0)
}

"ignore malformed X-B3-Flags header" in new WithApplication(disabledLocalSamplingApplication) {
val result = route(FakeRequest("GET", TestPath,
FakeHeaders(Seq(
TracingHeaders.Flags -> Seq("malformed")
)), AnyContentAsEmpty)).map(Await.result(_, defaultAwaitTimeout.duration))

receiveSpans() must beEmpty
}

"ignore broken TraceId header" in new WithApplication(fakeApplication) {
"ignore malformed X-B3-TraceId header" in new WithApplication(fakeApplication) {
val result = route(FakeRequest("GET", TestPath,
FakeHeaders(Seq(
TracingHeaders.TraceId -> Seq("broken")
TracingHeaders.TraceId -> Seq("malformed")
)), AnyContentAsEmpty)).map(Await.result(_, defaultAwaitTimeout.duration))

receiveSpans() must haveLength(1)
expectSpans(1)
}

"ignore broken SpanId header" in new WithApplication(fakeApplication) {
"ignore malformed X-B3-SpanId header" in new WithApplication(fakeApplication) {
val traceId = Random.nextLong

val result = route(FakeRequest("GET", TestPath,
FakeHeaders(Seq(
TracingHeaders.TraceId -> Seq(SpanMetadata.idToString(traceId)),
TracingHeaders.SpanId -> Seq("broken")
TracingHeaders.SpanId -> Seq("malformed")
)), AnyContentAsEmpty)).map(Await.result(_, defaultAwaitTimeout.duration))

val span = receiveSpan()
span.get_trace_id mustEqual traceId
}

"ignore broken ParentSpanId header" in new WithApplication(fakeApplication) {
"ignore malformed X-B3-ParentSpanId header" in new WithApplication(fakeApplication) {
val traceId = Random.nextLong

val result = route(FakeRequest("GET", TestPath,
FakeHeaders(Seq(
TracingHeaders.TraceId -> Seq(SpanMetadata.idToString(traceId)),
TracingHeaders.ParentSpanId -> Seq("broken")
TracingHeaders.ParentSpanId -> Seq("malformed")
)), AnyContentAsEmpty)).map(Await.result(_, defaultAwaitTimeout.duration))

val span = receiveSpan()
span.get_trace_id mustEqual traceId
}

"ignore broken Sampled header" in new WithApplication(fakeApplication) {
"ignore malformed X-B3-Sampled header" in new WithApplication(fakeApplication) {
val spanId = Random.nextLong
val result = route(FakeRequest("GET", TestPath + "?key=value",
FakeHeaders(Seq(
TracingHeaders.TraceId -> Seq(SpanMetadata.idToString(spanId)),
TracingHeaders.Sampled -> Seq("unexpected value")
TracingHeaders.Sampled -> Seq("malformed")
)), AnyContentAsEmpty)).map(Await.result(_, defaultAwaitTimeout.duration))
expectSpans(1)
}

}

step {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,14 @@ trait BaseTracingDirectives {
import TracingDirectives._

private[this] def tracedEntity[T <: TracingSupport](service: String)(implicit um: FromRequestUnmarshaller[T]): Directive[T :: HNil] =
hextract(ctx => ctx.request.as(um) :: extractSpan(ctx.request) :: ctx.request :: HNil).hflatMap[T :: HNil] {
hextract(ctx => ctx.request.as(um) :: extractSpan(ctx.request, requireTraceId = false) :: ctx.request :: HNil).hflatMap[T :: HNil] {
case Right(value) :: Right(maybeSpan) :: request :: HNil =>
maybeSpan match {
case Some(span) =>
trace.sample(value, span.spanId, span.parentId, span.traceId, service, span.forceSampling)
addHttpAnnotations(value.tracingId, request)
case _ =>
trace.sample(value.tracingId, service, value.spanName)
}
addHttpAnnotations(value.tracingId, request)
hprovide(value :: HNil)
case Right(value) :: Left(headerName) :: request :: HNil =>
reject(MalformedHeaderRejection(headerName, "invalid value"))
Expand Down Expand Up @@ -86,7 +85,7 @@ trait BaseTracingDirectives {
def tracedComplete[T](service: String, rpc: String)(value: => T)(implicit m: ToResponseMarshaller[T]): StandardRoute =
new StandardRoute {
def apply(ctx: RequestContext): Unit = {
extractSpan(ctx.request) match {
extractSpan(ctx.request, requireTraceId = true) match {
case Right(Some(span)) =>
// only requests with explicit tracing headers can be traced here, because we don't have
// any clues about spanId generated for unmarshalled entity
Expand Down Expand Up @@ -167,7 +166,7 @@ private[http] object TracingDirectives {

import TracingHeaders._

def extractSpan(message: HttpMessage): Either[String, Option[SpanMetadata]] = {
def extractSpan(message: HttpMessage, requireTraceId: Boolean): Either[String, Option[SpanMetadata]] = {
def headerStringValue(name: String): Option[String] =
message.headers.find(_.name == name).map(_.value)
def headerLongValue(name: String): Either[String, Option[Long]] =
Expand All @@ -177,25 +176,37 @@ private[http] object TracingDirectives {
case Success(v) =>
Right(v)
}
def isFlagSet(v: String, flag: Long): Boolean =
Try((java.lang.Long.parseLong(v) & flag) == flag).getOrElse(false)
// debug flag forces sampling (see http://git.io/hdEVug)
def forceSampling: Boolean =
headerStringValue(Flags).exists(isFlagSet(_, DebugFlag)) ||
headerStringValue(Sampled).contains("true")
def spanId: Long =
headerLongValue(SpanId).right.toOption.flatten.getOrElse(Random.nextLong)

headerLongValue(TraceId).right.map({
case Some(traceId) =>
headerLongValue(ParentSpanId).right.map { parentId =>
Some(SpanMetadata(traceId, spanId, parentId, forceSampling))
}
case None if forceSampling =>
Right(Some(SpanMetadata(Random.nextLong, spanId, None, true)))
case _ =>
// debug flag forces sampling (see http://git.io/hdEVug)
val maybeForceSampling =
headerStringValue(Sampled).map(_.toLowerCase) match {
case Some("0") | Some("false") =>
Some(false)
case Some("1") | Some("true") =>
Some(true)
case _ =>
headerStringValue(Flags).flatMap(flags =>
Try((java.lang.Long.parseLong(flags) & DebugFlag) == DebugFlag).toOption.filter(v => v))
}

maybeForceSampling match {
case Some(false) =>
Right(None)
}).joinRight
case _ =>
val forceSampling = maybeForceSampling.getOrElse(false)
headerLongValue(TraceId).right.map({
case Some(traceId) =>
headerLongValue(ParentSpanId).right.map { parentId =>
Some(SpanMetadata(traceId, spanId, parentId, forceSampling))
}
case _ if requireTraceId =>
Right(None)
case _ =>
Right(Some(SpanMetadata(Random.nextLong, spanId, None, forceSampling)))
}).joinRight
}
}

}
Expand Down
Loading

0 comments on commit f8f844f

Please sign in to comment.