Skip to content

Commit

Permalink
tests for MSELoss and BERLoss
Browse files Browse the repository at this point in the history
  • Loading branch information
flaport committed May 20, 2020
1 parent 8473a94 commit 65e1e8b
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 7 deletions.
17 changes: 10 additions & 7 deletions photontorch/torch_ext/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(
None if cutoff_frequency is None else float(cutoff_frequency)
)
self.filter_order = (
None if cutoff_frequency is None else int(filter_order + 0.5)
None if filter_order is None else int(filter_order + 0.5)
)
self.seed = None if seed is None else int(seed + 0.5)
self.device = torch.device("cpu") if device is None else torch.device(device)
Expand Down Expand Up @@ -111,12 +111,15 @@ def __call__(
if isinstance(bits, int):
bits = rng.rand(bits) > 0.5

# handle fractional sampling:
temp_samplerate = max(
int(8 * cutoff_frequency + 0.5) // int(samplerate + 0.5) * samplerate,
samplerate,
)
rc = int(temp_samplerate + 0.5) // int(samplerate + 0.5)
rc = 1
temp_samplerate = samplerate
if cutoff_frequency is not None:
# handle fractional sampling:
temp_samplerate = max(
int(8 * cutoff_frequency + 0.5) // int(samplerate + 0.5) * samplerate,
samplerate,
)
rc = int(temp_samplerate + 0.5) // int(samplerate + 0.5)
rates_gcd = np.gcd(int(temp_samplerate + 0.5), int(bitrate + 0.5))
rs = int(temp_samplerate + 0.5) // rates_gcd
rb = int(bitrate + 0.5) // rates_gcd
Expand Down
1 change: 1 addition & 0 deletions tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,4 @@ def conn():
d = pt.Detector()
conn = wg["ab"] * s["a"] * d["b"]
return conn

21 changes: 21 additions & 0 deletions tests/test_torch_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,27 @@ def test_buffer_repr():
s = repr(pt.nn.Buffer(data=torch.tensor([0.5])))
assert s.startswith("Buffer")

def test_ber():
berfunc = pt.BERLoss(bitrate=50e9, samplerate=160e9) # uneven sample rate
streamgenerator = pt.BitStreamGenerator(bitrate=50e9, samplerate=160e9)
output_bits = np.array([1,0,0,1,1,0,0,1,0,1])
target_bits = np.array([1,0,1,1,1,0,0,1,0,1]) # one bit difference
output_stream = streamgenerator(output_bits)
target_stream = streamgenerator(target_bits)
ber = berfunc(output_stream, target_stream)
assert ber == 0.1

def test_mse():
msefunc = pt.MSELoss(bitrate=50e9, samplerate=160e9) # uneven sample rate
streamgenerator = pt.BitStreamGenerator(bitrate=50e9, samplerate=160e9)
output_bits = np.array([1,0,0,1,1,0,0,1,0,1])
target_bits = np.array([1,0,1,1,1,0,0,1,0,1]) # one bit difference
output_stream = streamgenerator(output_bits).requires_grad_()
target_stream = streamgenerator(target_bits)
mse = msefunc(output_stream, target_stream)
assert np.allclose(mse.item(), 0.09375)
assert mse.requires_grad


###############
## Run Tests ##
Expand Down

0 comments on commit 65e1e8b

Please sign in to comment.