forked from vpnhood/VpnHood
-
Notifications
You must be signed in to change notification settings - Fork 0
/
StreamCryptor.cs
107 lines (87 loc) · 3.48 KB
/
StreamCryptor.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
using VpnHood.Common.Utils;
namespace VpnHood.Tunneling;
/// <summary>
/// Stream decorator that passes bytes read from, and written to, the wrapped stream through a
/// <see cref="BufferCryptor"/>. Each direction keeps its own running byte counter, and at most
/// the configured maximum number of bytes per direction is ciphered; bytes beyond that limit
/// pass through unchanged.
/// </summary>
public class StreamCryptor : AsyncStreamDecorator
{
    private readonly BufferCryptor _bufferCryptor;
    private readonly bool _leaveOpen;
    private readonly long _maxCipherCount;
    private readonly Stream _stream;
    private readonly bool _encryptInGivenBuffer;

    // Running cipher positions for the read and write directions respectively.
    private long _readCount;
    private long _writeCount;

    private StreamCryptor(Stream stream, byte[] key, long maxCipherCount,
        bool leaveOpen, bool encryptInGivenBuffer)
        : base(stream, leaveOpen)
    {
        if (key is null) throw new ArgumentNullException(nameof(key));
        _stream = stream ?? throw new ArgumentNullException(nameof(stream));
        _bufferCryptor = new BufferCryptor(key);
        _maxCipherCount = maxCipherCount;
        _leaveOpen = leaveOpen;
        _encryptInGivenBuffer = encryptInGivenBuffer;
    }

    /// <summary>Seeking is never supported by this stream.</summary>
    public override bool CanSeek => false;

    /// <summary>
    /// Creates a <see cref="StreamCryptor"/> over <paramref name="stream"/>.
    /// </summary>
    /// <param name="stream">The stream to wrap.</param>
    /// <param name="key">The cipher key. It is not modified; when a salt is given, a copy is salted.</param>
    /// <param name="salt">
    /// Optional salt, XOR-ed byte by byte into a copy of <paramref name="key"/>.
    /// Must have the same length as <paramref name="key"/>.
    /// </param>
    /// <param name="maxCipherPos">Maximum number of bytes to cipher in each direction.</param>
    /// <param name="leaveOpen">When true, <paramref name="stream"/> is not disposed by this instance.</param>
    /// <param name="encryptInGivenBuffer">
    /// When true, writes encrypt the caller's buffer in place; otherwise a copy is encrypted
    /// and the caller's buffer is left untouched.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="stream"/> or <paramref name="key"/> is null.</exception>
    /// <exception cref="Exception">The key and salt lengths differ.</exception>
    public static StreamCryptor Create(Stream stream, byte[] key, byte[]? salt = null, long maxCipherPos = long.MaxValue,
        bool leaveOpen = false, bool encryptInGivenBuffer = true)
    {
        if (stream is null) throw new ArgumentNullException(nameof(stream));
        if (key is null) throw new ArgumentNullException(nameof(key));

        var encKey = key;

        // apply salt if salt exists
        if (salt != null)
        {
            if (key.Length != salt.Length)
                throw new Exception($"{nameof(key)} length and {nameof(salt)} length is not same.");

            // Work on a clone so the caller's key array is never mutated.
            encKey = (byte[])key.Clone();
            for (var i = 0; i < encKey.Length; i++)
                encKey[i] ^= salt[i];
        }

        return new StreamCryptor(stream, encKey, maxCipherPos, leaveOpen, encryptInGivenBuffer);
    }

    /// <summary>
    /// Ciphers, in place, up to <paramref name="count"/> bytes of <paramref name="buffer"/> starting at
    /// <paramref name="offset"/>, using the read-side counter as the cipher position. Only the bytes
    /// that still fit under the maximum cipher count are processed.
    /// </summary>
    public void Decrypt(byte[] buffer, int offset, int count)
    {
        var cipherCount = Math.Min(count, _maxCipherCount - _readCount);
        if (cipherCount > 0)
        {
            lock (_bufferCryptor)
            {
                _bufferCryptor.Cipher(buffer, offset, (int)cipherCount, _readCount);
            }

            // Advance by the number of bytes actually ciphered, mirroring Encrypt, so the
            // counter never runs past the configured maximum.
            _readCount += cipherCount;
        }
    }

    /// <summary>
    /// Ciphers, in place, up to <paramref name="count"/> bytes of <paramref name="buffer"/> starting at
    /// <paramref name="offset"/>, using the write-side counter as the cipher position. Only the bytes
    /// that still fit under the maximum cipher count are processed.
    /// </summary>
    public void Encrypt(byte[] buffer, int offset, int count)
    {
        var cipherCount = Math.Min(count, _maxCipherCount - _writeCount);
        if (cipherCount > 0)
        {
            lock (_bufferCryptor)
            {
                _bufferCryptor.Cipher(buffer, offset, (int)cipherCount, _writeCount);
            }

            _writeCount += cipherCount;
        }
    }

    /// <summary>
    /// Reads from the wrapped stream and decrypts, in place, the bytes that were actually read.
    /// </summary>
    /// <returns>The number of bytes read from the wrapped stream.</returns>
    public override async Task<int> ReadAsync(byte[] buffer, int offset, int count,
        CancellationToken cancellationToken)
    {
        var readCount = await _stream.ReadAsync(buffer, offset, count, cancellationToken).VhConfigureAwait();
        Decrypt(buffer, offset, readCount);
        return readCount;
    }

    /// <summary>
    /// Encrypts <paramref name="count"/> bytes of <paramref name="buffer"/> starting at
    /// <paramref name="offset"/> and writes them to the wrapped stream. Depending on the
    /// encrypt-in-given-buffer option, either the caller's buffer is encrypted in place or a
    /// private copy of the requested segment is encrypted and written.
    /// </summary>
    public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
    {
        if (_encryptInGivenBuffer)
        {
            Encrypt(buffer, offset, count);
            return _stream.WriteAsync(buffer, offset, count, cancellationToken);
        }

        // A range takes an END index, not a length: the segment is [offset, offset + count).
        var copyBuffer = buffer[offset..(offset + count)];
        Encrypt(copyBuffer, 0, copyBuffer.Length);
        return _stream.WriteAsync(copyBuffer, 0, copyBuffer.Length, cancellationToken);
    }

    /// <summary>
    /// Disposes the cryptor, then the wrapped stream unless it was asked to be left open,
    /// and finally the base decorator.
    /// </summary>
    public override async ValueTask DisposeAsync()
    {
        lock (_bufferCryptor)
        {
            _bufferCryptor.Dispose();
        }

        // NOTE(review): the base class also received leaveOpen; whether it disposes the stream
        // a second time is not visible from this file - confirm against AsyncStreamDecorator.
        if (!_leaveOpen)
            await _stream.DisposeAsync().VhConfigureAwait();

        await base.DisposeAsync().VhConfigureAwait();
    }
}