Skip to content

Commit

Permalink
common.Must2
Browse files Browse the repository at this point in the history
  • Loading branch information
DarienRaymond committed Sep 19, 2017
1 parent 190adf1 commit 8971e69
Show file tree
Hide file tree
Showing 13 changed files with 103 additions and 74 deletions.
6 changes: 6 additions & 0 deletions common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,9 @@ func Must(err error) {
panic(err)
}
}

func Must2(v interface{}, err error) {
if err != nil {
panic(err)
}
}
8 changes: 4 additions & 4 deletions proxy/shadowsocks/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet
}
}

func (v *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection, dispatcher dispatcher.Interface) error {
func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection, dispatcher dispatcher.Interface) error {
udpServer := udp.NewDispatcher(dispatcher)

reader := buf.NewReader(conn)
Expand All @@ -81,7 +81,7 @@ func (v *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection
}

for _, payload := range mpayload {
request, data, err := DecodeUDPPacket(v.user, payload)
request, data, err := DecodeUDPPacket(s.user, payload)
if err != nil {
if source, ok := proxy.SourceFromContext(ctx); ok {
log.Trace(newError("dropping invalid UDP packet from: ", source).Base(err))
Expand All @@ -91,13 +91,13 @@ func (v *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection
continue
}

if request.Option.Has(RequestOptionOneTimeAuth) && v.account.OneTimeAuth == Account_Disabled {
if request.Option.Has(RequestOptionOneTimeAuth) && s.account.OneTimeAuth == Account_Disabled {
log.Trace(newError("client payload enables OTA but server doesn't allow it"))
payload.Release()
continue
}

if !request.Option.Has(RequestOptionOneTimeAuth) && v.account.OneTimeAuth == Account_Enabled {
if !request.Option.Has(RequestOptionOneTimeAuth) && s.account.OneTimeAuth == Account_Enabled {
log.Trace(newError("client payload disables OTA but server forces it"))
payload.Release()
continue
Expand Down
20 changes: 10 additions & 10 deletions proxy/vmess/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,32 +13,32 @@ type InternalAccount struct {
Security protocol.Security
}

func (v *InternalAccount) AnyValidID() *protocol.ID {
if len(v.AlterIDs) == 0 {
return v.ID
func (a *InternalAccount) AnyValidID() *protocol.ID {
if len(a.AlterIDs) == 0 {
return a.ID
}
return v.AlterIDs[dice.Roll(len(v.AlterIDs))]
return a.AlterIDs[dice.Roll(len(a.AlterIDs))]
}

func (v *InternalAccount) Equals(account protocol.Account) bool {
func (a *InternalAccount) Equals(account protocol.Account) bool {
vmessAccount, ok := account.(*InternalAccount)
if !ok {
return false
}
// TODO: handle AlterIds difference
return v.ID.Equals(vmessAccount.ID)
return a.ID.Equals(vmessAccount.ID)
}

func (v *Account) AsAccount() (protocol.Account, error) {
id, err := uuid.ParseString(v.Id)
func (a *Account) AsAccount() (protocol.Account, error) {
id, err := uuid.ParseString(a.Id)
if err != nil {
log.Trace(newError("failed to parse ID").Base(err).AtError())
return nil, err
}
protoID := protocol.NewID(id)
return &InternalAccount{
ID: protoID,
AlterIDs: protocol.NewAlterIDs(protoID, uint16(v.AlterId)),
Security: v.SecuritySettings.AsSecurity(),
AlterIDs: protocol.NewAlterIDs(protoID, uint16(a.AlterId)),
Security: a.SecuritySettings.AsSecurity(),
}, nil
}
11 changes: 6 additions & 5 deletions proxy/vmess/encoding/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@ import (
"crypto/md5"
"hash/fnv"

"golang.org/x/crypto/sha3"

"v2ray.com/core/common"
"v2ray.com/core/common/serial"

"golang.org/x/crypto/sha3"
)

// Authenticate authenticates a byte array using Fnv hash.
func Authenticate(b []byte) uint32 {
fnv1hash := fnv.New32a()
fnv1hash.Write(b)
common.Must2(fnv1hash.Write(b))
return fnv1hash.Sum32()
}

Expand Down Expand Up @@ -81,7 +82,7 @@ type ShakeSizeParser struct {

func NewShakeSizeParser(nonce []byte) *ShakeSizeParser {
shake := sha3.NewShake128()
shake.Write(nonce)
common.Must2(shake.Write(nonce))
return &ShakeSizeParser{
shake: shake,
}
Expand All @@ -92,7 +93,7 @@ func (*ShakeSizeParser) SizeBytes() int {
}

func (s *ShakeSizeParser) next() uint16 {
s.shake.Read(s.buffer[:])
common.Must2(s.shake.Read(s.buffer[:]))
return serial.BytesToUint16(s.buffer[:])
}

Expand Down
49 changes: 24 additions & 25 deletions proxy/vmess/encoding/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"golang.org/x/crypto/chacha20poly1305"

"v2ray.com/core/app/log"
"v2ray.com/core/common"
"v2ray.com/core/common/buf"
"v2ray.com/core/common/crypto"
"v2ray.com/core/common/dice"
Expand Down Expand Up @@ -43,7 +44,7 @@ type ClientSession struct {
// NewClientSession creates a new ClientSession.
func NewClientSession(idHash protocol.IDHash) *ClientSession {
randomBytes := make([]byte, 33) // 16 + 16 + 1
rand.Read(randomBytes)
common.Must2(rand.Read(randomBytes))

session := &ClientSession{}
session.requestBodyKey = randomBytes[:16]
Expand All @@ -58,22 +59,22 @@ func NewClientSession(idHash protocol.IDHash) *ClientSession {
return session
}

func (v *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writer io.Writer) {
func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writer io.Writer) {
timestamp := protocol.NewTimestampGenerator(protocol.NowTime(), 30)()
account, err := header.User.GetTypedAccount()
if err != nil {
log.Trace(newError("failed to get user account: ", err).AtError())
return
}
idHash := v.idHash(account.(*vmess.InternalAccount).AnyValidID().Bytes())
idHash.Write(timestamp.Bytes(nil))
writer.Write(idHash.Sum(nil))
idHash := c.idHash(account.(*vmess.InternalAccount).AnyValidID().Bytes())
common.Must2(idHash.Write(timestamp.Bytes(nil)))
common.Must2(writer.Write(idHash.Sum(nil)))

buffer := make([]byte, 0, 512)
buffer = append(buffer, Version)
buffer = append(buffer, v.requestBodyIV...)
buffer = append(buffer, v.requestBodyKey...)
buffer = append(buffer, v.responseHeader, byte(header.Option))
buffer = append(buffer, c.requestBodyIV...)
buffer = append(buffer, c.requestBodyKey...)
buffer = append(buffer, c.responseHeader, byte(header.Option))
padingLen := dice.Roll(16)
if header.Security.Is(protocol.SecurityType_LEGACY) {
// Disable padding in legacy mode for a smooth transition.
Expand All @@ -100,29 +101,27 @@ func (v *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writ

if padingLen > 0 {
pading := make([]byte, padingLen)
rand.Read(pading)
common.Must2(rand.Read(pading))
buffer = append(buffer, pading...)
}

fnv1a := fnv.New32a()
fnv1a.Write(buffer)
common.Must2(fnv1a.Write(buffer))

buffer = fnv1a.Sum(buffer)

timestampHash := md5.New()
timestampHash.Write(hashTimestamp(timestamp))
common.Must2(timestampHash.Write(hashTimestamp(timestamp)))
iv := timestampHash.Sum(nil)
aesStream := crypto.NewAesEncryptionStream(account.(*vmess.InternalAccount).ID.CmdKey(), iv)
aesStream.XORKeyStream(buffer, buffer)
writer.Write(buffer)

return
common.Must2(writer.Write(buffer))
}

func (v *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, writer io.Writer) buf.Writer {
func (c *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, writer io.Writer) buf.Writer {
var sizeParser crypto.ChunkSizeEncoder = crypto.PlainChunkSizeParser{}
if request.Option.Has(protocol.RequestOptionChunkMasking) {
sizeParser = NewShakeSizeParser(v.requestBodyIV)
sizeParser = NewShakeSizeParser(c.requestBodyIV)
}
if request.Security.Is(protocol.SecurityType_NONE) {
if request.Option.Has(protocol.RequestOptionChunkStream) {
Expand All @@ -141,7 +140,7 @@ func (v *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, write
}

if request.Security.Is(protocol.SecurityType_LEGACY) {
aesStream := crypto.NewAesEncryptionStream(v.requestBodyKey, v.requestBodyIV)
aesStream := crypto.NewAesEncryptionStream(c.requestBodyKey, c.requestBodyIV)
cryptionWriter := crypto.NewCryptionWriter(aesStream, writer)
if request.Option.Has(protocol.RequestOptionChunkStream) {
auth := &crypto.AEADAuthenticator{
Expand All @@ -156,13 +155,13 @@ func (v *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, write
}

if request.Security.Is(protocol.SecurityType_AES128_GCM) {
block, _ := aes.NewCipher(v.requestBodyKey)
block, _ := aes.NewCipher(c.requestBodyKey)
aead, _ := cipher.NewGCM(block)

auth := &crypto.AEADAuthenticator{
AEAD: aead,
NonceGenerator: &ChunkNonceGenerator{
Nonce: append([]byte(nil), v.requestBodyIV...),
Nonce: append([]byte(nil), c.requestBodyIV...),
Size: aead.NonceSize(),
},
AdditionalDataGenerator: crypto.NoOpBytesGenerator{},
Expand All @@ -171,12 +170,12 @@ func (v *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, write
}

if request.Security.Is(protocol.SecurityType_CHACHA20_POLY1305) {
aead, _ := chacha20poly1305.New(GenerateChacha20Poly1305Key(v.requestBodyKey))
aead, _ := chacha20poly1305.New(GenerateChacha20Poly1305Key(c.requestBodyKey))

auth := &crypto.AEADAuthenticator{
AEAD: aead,
NonceGenerator: &ChunkNonceGenerator{
Nonce: append([]byte(nil), v.requestBodyIV...),
Nonce: append([]byte(nil), c.requestBodyIV...),
Size: aead.NonceSize(),
},
AdditionalDataGenerator: crypto.NoOpBytesGenerator{},
Expand Down Expand Up @@ -299,8 +298,8 @@ type ChunkNonceGenerator struct {
count uint16
}

func (v *ChunkNonceGenerator) Next() []byte {
serial.Uint16ToBytes(v.count, v.Nonce[:0])
v.count++
return v.Nonce[:v.Size]
func (g *ChunkNonceGenerator) Next() []byte {
serial.Uint16ToBytes(g.count, g.Nonce[:0])
g.count++
return g.Nonce[:g.Size]
}
23 changes: 12 additions & 11 deletions proxy/vmess/encoding/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package encoding
import (
"io"

"v2ray.com/core/common"
"v2ray.com/core/common/buf"
"v2ray.com/core/common/net"
"v2ray.com/core/common/protocol"
Expand Down Expand Up @@ -45,8 +46,8 @@ func MarshalCommand(command interface{}, writer io.Writer) error {
return ErrCommandTooLarge
}

writer.Write([]byte{cmdID, byte(len), byte(auth >> 24), byte(auth >> 16), byte(auth >> 8), byte(auth)})
writer.Write(buffer.Bytes())
common.Must2(writer.Write([]byte{cmdID, byte(len), byte(auth >> 24), byte(auth >> 16), byte(auth >> 8), byte(auth)}))
common.Must2(writer.Write(buffer.Bytes()))
return nil
}

Expand Down Expand Up @@ -78,7 +79,7 @@ type CommandFactory interface {
type CommandSwitchAccountFactory struct {
}

func (v *CommandSwitchAccountFactory) Marshal(command interface{}, writer io.Writer) error {
func (f *CommandSwitchAccountFactory) Marshal(command interface{}, writer io.Writer) error {
cmd, ok := command.(*protocol.CommandSwitchAccount)
if !ok {
return ErrCommandTypeMismatch
Expand All @@ -88,25 +89,25 @@ func (v *CommandSwitchAccountFactory) Marshal(command interface{}, writer io.Wri
if cmd.Host != nil {
hostStr = cmd.Host.String()
}
writer.Write([]byte{byte(len(hostStr))})
common.Must2(writer.Write([]byte{byte(len(hostStr))}))

if len(hostStr) > 0 {
writer.Write([]byte(hostStr))
common.Must2(writer.Write([]byte(hostStr)))
}

writer.Write(cmd.Port.Bytes(nil))
common.Must2(writer.Write(cmd.Port.Bytes(nil)))

idBytes := cmd.ID.Bytes()
writer.Write(idBytes)
common.Must2(writer.Write(idBytes))

writer.Write(serial.Uint16ToBytes(cmd.AlterIds, nil))
writer.Write([]byte{byte(cmd.Level)})
common.Must2(writer.Write(serial.Uint16ToBytes(cmd.AlterIds, nil)))
common.Must2(writer.Write([]byte{byte(cmd.Level)}))

writer.Write([]byte{cmd.ValidMin})
common.Must2(writer.Write([]byte{cmd.ValidMin}))
return nil
}

func (v *CommandSwitchAccountFactory) Unmarshal(data []byte) (interface{}, error) {
func (f *CommandSwitchAccountFactory) Unmarshal(data []byte) (interface{}, error) {
cmd := new(protocol.CommandSwitchAccount)
if len(data) == 0 {
return nil, newError("insufficient length.")
Expand Down
9 changes: 5 additions & 4 deletions proxy/vmess/encoding/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"time"

"golang.org/x/crypto/chacha20poly1305"
"v2ray.com/core/common"
"v2ray.com/core/common/buf"
"v2ray.com/core/common/crypto"
"v2ray.com/core/common/net"
Expand Down Expand Up @@ -126,7 +127,7 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
}

timestampHash := md5.New()
timestampHash.Write(hashTimestamp(timestamp))
common.Must2(timestampHash.Write(hashTimestamp(timestamp)))
iv := timestampHash.Sum(nil)
account, err := user.GetTypedAccount()
if err != nil {
Expand Down Expand Up @@ -220,7 +221,7 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
}

fnv1a := fnv.New32a()
fnv1a.Write(buffer[:bufferLen])
common.Must2(fnv1a.Write(buffer[:bufferLen]))
actualHash := fnv1a.Sum32()
expectedHash := serial.BytesToUint32(buffer[bufferLen : bufferLen+4])

Expand Down Expand Up @@ -314,10 +315,10 @@ func (s *ServerSession) EncodeResponseHeader(header *protocol.ResponseHeader, wr
encryptionWriter := crypto.NewCryptionWriter(aesStream, writer)
s.responseWriter = encryptionWriter

encryptionWriter.Write([]byte{s.responseHeader, byte(header.Option)})
common.Must2(encryptionWriter.Write([]byte{s.responseHeader, byte(header.Option)}))
err := MarshalCommand(header.Command, encryptionWriter)
if err != nil {
encryptionWriter.Write([]byte{0x00, 0x00})
common.Must2(encryptionWriter.Write([]byte{0x00, 0x00}))
}
}

Expand Down
2 changes: 1 addition & 1 deletion proxy/vmess/inbound/inbound.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ func transferRequest(timer signal.ActivityTimer, session *encoding.ServerSession

bodyReader := session.DecodeRequestBody(request, input)
if err := buf.Copy(bodyReader, output, buf.UpdateActivity(timer)); err != nil {
return err
return newError("failed to transfer request").Base(err)
}
return nil
}
Expand Down
8 changes: 4 additions & 4 deletions proxy/vmess/outbound/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"v2ray.com/core/proxy/vmess"
)

func (v *Handler) handleSwitchAccount(cmd *protocol.CommandSwitchAccount) {
func (h *Handler) handleSwitchAccount(cmd *protocol.CommandSwitchAccount) {
account := &vmess.Account{
Id: cmd.ID.String(),
AlterId: uint32(cmd.AlterIds),
Expand All @@ -25,16 +25,16 @@ func (v *Handler) handleSwitchAccount(cmd *protocol.CommandSwitchAccount) {
}
dest := net.TCPDestination(cmd.Host, cmd.Port)
until := time.Now().Add(time.Duration(cmd.ValidMin) * time.Minute)
v.serverList.AddServer(protocol.NewServerSpec(dest, protocol.BeforeTime(until), user))
h.serverList.AddServer(protocol.NewServerSpec(dest, protocol.BeforeTime(until), user))
}

func (v *Handler) handleCommand(dest net.Destination, cmd protocol.ResponseCommand) {
func (h *Handler) handleCommand(dest net.Destination, cmd protocol.ResponseCommand) {
switch typedCommand := cmd.(type) {
case *protocol.CommandSwitchAccount:
if typedCommand.Host == nil {
typedCommand.Host = dest.Address
}
v.handleSwitchAccount(typedCommand)
h.handleSwitchAccount(typedCommand)
default:
}
}
Loading

0 comments on commit 8971e69

Please sign in to comment.