Skip to content

Commit

Permalink
[jwt] Support fractional dates (lestrrat-go#732)
Browse files Browse the repository at this point in the history
* Add fractional time support

* Appease linter

* Appease linter

* avoid time.Parse

* tweak
  • Loading branch information
lestrrat authored May 6, 2022
1 parent ec80811 commit e0fac29
Show file tree
Hide file tree
Showing 7 changed files with 343 additions and 34 deletions.
104 changes: 88 additions & 16 deletions jwt/internal/types/date.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,20 @@ package types
import (
"fmt"
"strconv"
"strings"
"time"

"github.com/lestrrat-go/jwx/v2/internal/json"
)

const (
DefaultPrecision uint32 = 0 // second level
MaxPrecision uint32 = 9 // nanosecond level
)

var ParsePrecision = DefaultPrecision
var FormatPrecision = DefaultPrecision

// NumericDate represents the date format used in the 'nbf' claim
type NumericDate struct {
time.Time
Expand All @@ -20,7 +29,7 @@ func (n *NumericDate) Get() time.Time {
return n.Time
}

func numericToTime(v interface{}, t *time.Time) bool {
func intToTime(v interface{}, t *time.Time) bool {
var n int64
switch x := v.(type) {
case int64:
Expand All @@ -33,10 +42,6 @@ func numericToTime(v interface{}, t *time.Time) bool {
n = int64(x)
case int:
n = int64(x)
case float32:
n = int64(x)
case float64:
n = int64(x)
default:
return false
}
Expand All @@ -45,41 +50,108 @@ func numericToTime(v interface{}, t *time.Time) bool {
return true
}

func parseNumericString(x string) (time.Time, error) {
var t time.Time // empty time for empty return value
var fractional string
whole := x
if i := strings.IndexRune(x, '.'); i > 0 {
if ParsePrecision > 0 && len(x) > i+1 {
fractional = x[i+1:] // everything after the '.'
if int(ParsePrecision) < len(fractional) {
// Remove insignificant digits
fractional = fractional[:int(ParsePrecision)]
}
// Replace missing fractional diits with zeros
for len(fractional) < int(MaxPrecision) {
fractional = fractional + "0"
}
}
whole = x[:i]
}
n, err := strconv.ParseInt(whole, 10, 64)
if err != nil {
return t, fmt.Errorf(`failed to parse whole value %q: %w`, whole, err)
}
var nsecs int64
if fractional != "" {
v, err := strconv.ParseInt(fractional, 10, 64)
if err != nil {
return t, fmt.Errorf(`failed to parse fractional value %q: %w`, fractional, err)
}
nsecs = v
}

return time.Unix(n, nsecs).UTC(), nil
}

func (n *NumericDate) Accept(v interface{}) error {
var t time.Time

switch x := v.(type) {
case string:
i, err := strconv.ParseInt(x[:], 10, 64)
case float32:
tv, err := parseNumericString(fmt.Sprintf(`%.9f`, x))
if err != nil {
return fmt.Errorf(`invalid epoch value %#v`, x)
return fmt.Errorf(`failed to accept float32 %.9f: %w`, x, err)
}
t = time.Unix(i, 0)

t = tv
case float64:
tv, err := parseNumericString(fmt.Sprintf(`%.9f`, x))
if err != nil {
return fmt.Errorf(`failed to accept float32 %.9f: %w`, x, err)
}
t = tv
case json.Number:
intval, err := x.Int64()
tv, err := parseNumericString(x.String())
if err != nil {
return fmt.Errorf(`failed to convert json value %#v to int64: %w`, x, err)
return fmt.Errorf(`failed to accept json.Number %q: %w`, x.String(), err)
}
t = time.Unix(intval, 0)
t = tv
case string:
tv, err := parseNumericString(x)
if err != nil {
return fmt.Errorf(`failed to accept string %q: %w`, x, err)
}
t = tv
case time.Time:
t = x
default:
if !numericToTime(v, &t) {
if !intToTime(v, &t) {
return fmt.Errorf(`invalid type %T`, v)
}
}
n.Time = t.UTC()
return nil
}

func (n NumericDate) String() string {
if FormatPrecision == 0 {
return strconv.FormatInt(n.Unix(), 10)
}

// This is cheating,but it's better (easier) than doing floating point math
// We basically munge with strings after formatting an integer balue
// for nanoseconds since epoch
s := strconv.FormatInt(n.UnixNano(), 10)
for len(s) < int(MaxPrecision) {
s = "0" + s
}

slwhole := len(s) - int(MaxPrecision)
s = s[:slwhole] + "." + s[slwhole:slwhole+int(FormatPrecision)]
if s[0] == '.' {
s = "0" + s
}

return s
}

// MarshalJSON translates from internal representation to JSON NumericDate
// See https://tools.ietf.org/html/rfc7519#page-6
func (n *NumericDate) MarshalJSON() ([]byte, error) {
if n.IsZero() {
return json.Marshal(nil)
}
return json.Marshal(n.Unix())

return json.Marshal(n.String())
}

func (n *NumericDate) UnmarshalJSON(data []byte) error {
Expand Down
76 changes: 64 additions & 12 deletions jwt/internal/types/date_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,13 @@ import (
)

func TestDate(t *testing.T) {
t.Parallel()
t.Run("Get from a nil NumericDate", func(t *testing.T) {
t.Parallel()
var n *types.NumericDate
if !assert.Equal(t, time.Time{}, n.Get()) {
return
}
})
t.Run("MarshalJSON with a zero value", func(t *testing.T) {
t.Parallel()
var n *types.NumericDate
buf, err := json.Marshal(n)
if !assert.NoError(t, err, `json.Marshal against a zero value should succeed`) {
Expand All @@ -33,18 +30,73 @@ func TestDate(t *testing.T) {
return
}
})

// This test alters global behavior, and can't be ran in parallel
t.Run("Accept values", func(t *testing.T) {
t.Parallel()
// NumericDate allows assignment from various different Go types,
// so that it's easier for the devs, and conversion to/from JSON
// use of "127" is just to allow use of int8's
now := time.Unix(127, 0).UTC()
for _, ut := range []interface{}{int64(127), int32(127), int16(127), int8(127), float32(127), float64(127), json.Number("127")} {
ut := ut
t.Run(fmt.Sprintf("%T", ut), func(t *testing.T) {
t.Parallel()
testcases := []struct {
Input interface{}
Expected time.Time
Precision int
}{
{
Input: int64(127),
Expected: time.Unix(127, 0).UTC(),
},
{
Input: int32(127),
Expected: time.Unix(127, 0).UTC(),
},
{
Input: int16(127),
Expected: time.Unix(127, 0).UTC(),
},
{
Input: int8(127),
Expected: time.Unix(127, 0).UTC(),
},
{
Input: float32(127.11),
Expected: time.Unix(127, 0).UTC(),
},
{
Input: float32(127.11),
Expected: time.Unix(127, 0).UTC(),
},
{
Input: json.Number("127"),
Expected: time.Unix(127, 0).UTC(),
},
{
Input: json.Number("127.11"),
Expected: time.Unix(127, 0).UTC(),
},
{
Input: json.Number("127.11"),
Expected: time.Unix(127, 110000000).UTC(),
Precision: 4,
},
{
Input: json.Number("127.110000011"),
Expected: time.Unix(127, 110000011).UTC(),
Precision: 9,
},
{
Input: json.Number("127.110000011111"),
Expected: time.Unix(127, 110000011).UTC(),
Precision: 9,
},
}

for _, tc := range testcases {
tc := tc
precision := tc.Precision
t.Run(fmt.Sprintf("%v(type=%T, precision=%d)", tc.Input, tc.Input, precision), func(t *testing.T) {
jwt.Settings(jwt.WithNumericDateParsePrecision(precision))

t1 := jwt.New()
err := t1.Set(jwt.IssuedAtKey, ut)
err := t1.Set(jwt.IssuedAtKey, tc.Input)
if !assert.NoError(t, err) {
return
}
Expand All @@ -53,7 +105,7 @@ func TestDate(t *testing.T) {
return
}
realized := v.(time.Time)
if !assert.Equal(t, now, realized) {
if !assert.Equal(t, tc.Expected, realized) {
return
}
})
Expand Down
43 changes: 37 additions & 6 deletions jwt/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,58 @@ import (
"github.com/lestrrat-go/jwx/v2"
"github.com/lestrrat-go/jwx/v2/internal/json"
"github.com/lestrrat-go/jwx/v2/jws"
"github.com/lestrrat-go/jwx/v2/jwt/internal/types"
)

// Settings controls global settings that are specific to JWTs.
func Settings(options ...GlobalOption) {
var flattenAudienceBool bool
var parsePrecision = types.MaxPrecision + 1 // illegal value, so we can detect nothing was set
var formatPrecision = types.MaxPrecision + 1 // illegal value, so we can detect nothing was set

//nolint:forcetypeassert
for _, option := range options {
switch option.Ident() {
case identFlattenAudience{}:
flattenAudienceBool = option.Value().(bool)
case identNumericDateParsePrecision{}:
v := option.Value().(int)
// only accept this value if it's in our desired range
if v >= 0 && v <= int(types.MaxPrecision) {
parsePrecision = uint32(v)
}
case identNumericDateFormatPrecision{}:
v := option.Value().(int)
// only accept this value if it's in our desired range
if v >= 0 && v <= int(types.MaxPrecision) {
formatPrecision = uint32(v)
}
}
}

if parsePrecision <= types.MaxPrecision { // remember we set default to max + 1
v := atomic.LoadUint32(&types.ParsePrecision)
if v != parsePrecision {
atomic.CompareAndSwapUint32(&types.ParsePrecision, v, parsePrecision)
}
}

v := atomic.LoadUint32(&json.FlattenAudience)
if (v == 1) != flattenAudienceBool {
var newVal uint32
if flattenAudienceBool {
newVal = 1
if formatPrecision <= types.MaxPrecision { // remember we set default to max + 1
v := atomic.LoadUint32(&types.FormatPrecision)
if v != formatPrecision {
atomic.CompareAndSwapUint32(&types.FormatPrecision, v, formatPrecision)
}
}

{
v := atomic.LoadUint32(&json.FlattenAudience)
if (v == 1) != flattenAudienceBool {
var newVal uint32
if flattenAudienceBool {
newVal = 1
}
atomic.CompareAndSwapUint32(&json.FlattenAudience, v, newVal)
}
atomic.CompareAndSwapUint32(&json.FlattenAudience, v, newVal)
}
}

Expand Down
Loading

0 comments on commit e0fac29

Please sign in to comment.