gateware: fix HalfBandInterpolator backpressure issues

This commit is contained in:
mndza
2026-01-09 16:23:11 +01:00
parent 29bfc3b78b
commit 76318321f5
2 changed files with 167 additions and 101 deletions

Binary file not shown.

View File

@@ -7,7 +7,7 @@
from math import ceil, log2
from amaranth import Module, Signal, Mux, DomainRenamer
from amaranth.lib import wiring, stream, data, memory
from amaranth.lib import wiring, stream, data, memory, fifo
from amaranth.lib.wiring import In, Out
from amaranth.utils import bits_for
@@ -173,9 +173,17 @@ class HalfBandInterpolator(wiring.Component):
delay = arm1_taps.index(1)
# Arms
m.submodules.fir0 = fir0 = FIRFilter(arm0_taps, shape=self.data_shape, shape_out=self.shape_out, always_ready=always_ready, num_channels=self.num_channels)
m.submodules.fir1 = fir1 = Delay(delay, shape=self.data_shape, always_ready=always_ready, num_channels=self.num_channels)
arms = [fir0, fir1]
m.submodules.fir = fir = FIRFilter(arm0_taps, shape=self.data_shape, shape_out=self.shape_out, always_ready=always_ready, num_channels=self.num_channels)
m.submodules.dly = dly = Delay(delay, shape=self.data_shape, always_ready=always_ready, num_channels=self.num_channels)
m.submodules.dly_fifo = dly_fifo = fifo.SyncFIFOBuffered(width=self.num_channels*self.data_shape.as_shape().width, depth=1)
arms = [fir, dly]
m.d.comb += [
dly_fifo.w_data.eq(dly.output.p),
dly_fifo.w_en.eq(dly.output.valid),
]
if not dly.output.signature.always_ready:
m.d.comb += dly.output.ready.eq(dly_fifo.w_rdy)
with m.FSM():
@@ -198,7 +206,6 @@ class HalfBandInterpolator(wiring.Component):
m.next = "BYPASS"
# Input
for i, arm in enumerate(arms):
m.d.comb += arm.input.payload.eq(self.input.payload)
m.d.comb += arm.input.valid.eq(self.input.valid & arms[i^1].input.ready)
@@ -211,29 +218,25 @@ class HalfBandInterpolator(wiring.Component):
arm_index = Signal()
# Output buffers for each arm.
arm_outputs = [arm.output for arm in arms]
if self.output.signature.always_ready:
buffers = [stream.Signature(arm.payload.shape()).create() for arm in arm_outputs]
for arm, buf in zip(arm_outputs, buffers):
with m.If(~buf.valid | buf.ready):
if not arm.signature.always_ready:
m.d.comb += arm.ready.eq(1)
m.d.sync += buf.valid.eq(arm.valid)
with m.If(arm.valid):
m.d.sync += buf.payload.eq(arm.payload)
arm_outputs = buffers
r_data_cast = data.ArrayLayout(self.data_shape, self.num_channels)(dly_fifo.r_data)
with m.If(~self.output.valid | self.output.ready):
with m.Switch(arm_index):
for i, arm in enumerate(arm_outputs):
with m.Case(i):
for c in range(self.num_channels):
m.d.sync += self.output.payload[c].eq(arm.payload[c])
m.d.sync += self.output.valid.eq(arm.valid)
if not arm.signature.always_ready:
m.d.comb += arm.ready.eq(1)
with m.If(arm.valid):
m.d.sync += arm_index.eq(arm_index ^ 1)
with m.Case(0):
for c in range(self.num_channels):
m.d.sync += self.output.payload[c].eq(fir.output.payload[c])
m.d.sync += self.output.valid.eq(fir.output.valid)
if not fir.output.signature.always_ready:
m.d.comb += fir.output.ready.eq(1)
with m.If(fir.output.valid):
m.d.sync += arm_index.eq(1)
with m.Case(1):
for c in range(self.num_channels):
m.d.sync += self.output.payload[c].eq(r_data_cast[c])
m.d.sync += self.output.valid.eq(dly_fifo.r_rdy)
m.d.comb += dly_fifo.r_en.eq(1)
with m.If(dly_fifo.r_rdy):
m.d.sync += arm_index.eq(0)
if self._domain != "sync":
m = DomainRenamer(self._domain)(m)
@@ -439,24 +442,26 @@ class _TestFilter(unittest.TestCase):
return samples / (1 << f_width)
return samples
def _filter(self, dut, samples, count, num_channels=1, outfile=None, empty_cycles=0):
def _filter(self, dut, samples, count, num_channels=1, outfile=None, empty_cycles=0, empty_ready_cycles=0):
async def input_process(ctx):
if hasattr(dut, "enable"):
ctx.set(dut.enable, 1)
await ctx.tick()
ctx.set(dut.input.valid, 1)
for sample in samples:
await ctx.tick()
for i, sample in enumerate(samples):
if num_channels > 1:
ctx.set(dut.input.payload, [s.item() for s in sample])
else:
ctx.set(dut.input.payload, [sample.item()])
if isinstance(dut.input.payload.shape(), data.ArrayLayout):
ctx.set(dut.input.payload, [sample.item()])
else:
ctx.set(dut.input.payload, sample.item())
ctx.set(dut.input.valid, 1)
await ctx.tick().until(dut.input.ready)
ctx.set(dut.input.valid, 0)
if empty_cycles > 0:
ctx.set(dut.input.valid, 0)
await ctx.tick().repeat(empty_cycles)
ctx.set(dut.input.valid, 1)
ctx.set(dut.input.valid, 0)
filtered = []
async def output_process(ctx):
@@ -467,7 +472,14 @@ class _TestFilter(unittest.TestCase):
if num_channels > 1:
filtered.append([v.as_float() for v in payload])
else:
filtered.append(payload[0].as_float())
if isinstance(payload.shape(), data.ArrayLayout):
filtered.append(payload[0].as_float())
else:
filtered.append(payload.as_float())
if empty_ready_cycles > 0:
ctx.set(dut.output.ready, 0)
await ctx.tick().repeat(empty_ready_cycles)
ctx.set(dut.output.ready, 1)
if not dut.output.signature.always_ready:
ctx.set(dut.output.ready, 0)
@@ -498,100 +510,154 @@ class TestFIRFilter(_TestFilter):
filtered_np = np.convolve(input_samples, taps).tolist()
# Simulate DUT
dut = FIRFilter(taps, fixed.SQ(15, 0), always_ready=True)
filtered = self._filter(dut, input_samples, len(input_samples))
dut = FIRFilter(taps, shape=fixed.SQ(8, 0), always_ready=False)
filtered = self._filter(dut, input_samples, len(input_samples), empty_ready_cycles=5)
self.assertListEqual(filtered_np[:len(filtered)], filtered)
class TestHalfBandDecimator(_TestFilter):
def test_filter_no_backpressure(self):
taps = [-1, 0, 9, 16, 9, 0, -1]
taps = [ tap / 32 for tap in taps ]
def test_filter(self):
num_samples = 1024
input_width = 8
samples_i_in = self._generate_samples(num_samples, input_width, f_width=7)
samples_q_in = self._generate_samples(num_samples, input_width, f_width=7)
common_dut_options = dict(
data_shape=fixed.SQ(7),
shape_out=fixed.SQ(0,31),
)
# Compute the expected result
filtered_i_np = np.convolve(samples_i_in, taps)[1::2].tolist()
filtered_q_np = np.convolve(samples_q_in, taps)[1::2].tolist()
taps0 = (np.array([-1, 0, 9, 16, 9, 0, -1]) / 32).tolist()
taps1 = (np.array([-2, 0, 7, 0, -18, 0, 41, 0, -92, 0, 320, 512, 320, 0, -92, 0, 41, 0, -18, 0, 7, 0, -2]) / 1024).tolist()
# Simulate DUT
dut = HalfBandDecimator(taps, data_shape=fixed.SQ(7), shape_out=fixed.SQ(0,16), always_ready=True)
filtered = self._filter(dut, zip(samples_i_in, samples_q_in), len(samples_i_in) // 2, num_channels=2)
filtered_i = [ x[0] for x in filtered ]
filtered_q = [ x[1] for x in filtered ]
self.assertListEqual(filtered_i_np[:len(filtered_i)], filtered_i)
self.assertListEqual(filtered_q_np[:len(filtered_q)], filtered_q)
inputs = {
def test_filter_with_spare_cycles(self):
taps = [-1, 0, 9, 16, 9, 0, -1]
taps = [ tap / 32 for tap in taps ]
"test_filter_with_backpressure": {
"num_samples": 1024,
"dut_options": dict(**common_dut_options, always_ready=False, taps=taps0),
"sim_opts": dict(empty_cycles=0),
},
num_samples = 1024
input_width = 8
samples_i_in = self._generate_samples(num_samples, input_width, f_width=7)
samples_q_in = self._generate_samples(num_samples, input_width, f_width=7)
"test_filter_with_backpressure_and_empty_cycles": {
"num_samples": 1024,
"dut_options": dict(**common_dut_options, always_ready=False, taps=taps0),
"sim_opts": dict(empty_cycles=3),
},
# Compute the expected result
filtered_i_np = np.convolve(samples_i_in, taps)[1::2].tolist()
filtered_q_np = np.convolve(samples_q_in, taps)[1::2].tolist()
"test_filter_with_backpressure_taps1": {
"num_samples": 1024,
"dut_options": dict(**common_dut_options, always_ready=False, taps=taps1),
"sim_opts": dict(empty_cycles=0),
},
# Simulate DUT
dut = HalfBandDecimator(taps, data_shape=fixed.SQ(7), shape_out=fixed.SQ(0,16), always_ready=True)
filtered = self._filter(dut, zip(samples_i_in, samples_q_in), len(samples_i_in) // 2, num_channels=2, empty_cycles=3)
filtered_i = [ x[0] for x in filtered ]
filtered_q = [ x[1] for x in filtered ]
"test_filter_no_backpressure_and_empty_cycles_taps1": {
"num_samples": 1024,
"dut_options": dict(**common_dut_options, always_ready=True, taps=taps0),
"sim_opts": dict(empty_cycles=6),
},
self.assertListEqual(filtered_i_np[:len(filtered_i)], filtered_i)
self.assertListEqual(filtered_q_np[:len(filtered_q)], filtered_q)
"test_filter_no_backpressure": {
"num_samples": 1024,
"dut_options": dict(**common_dut_options, always_ready=True, taps=taps1),
"sim_opts": dict(empty_cycles=3),
},
}
for name, scenario in inputs.items():
def test_filter_with_backpressure(self):
taps = [-1, 0, 9, 16, 9, 0, -1]
taps = [ tap / 32 for tap in taps ]
with self.subTest(name):
taps = scenario["dut_options"]["taps"]
num_samples = scenario["num_samples"]
num_samples = 1024
input_width = 8
samples_i_in = self._generate_samples(num_samples, input_width, f_width=7)
samples_q_in = self._generate_samples(num_samples, input_width, f_width=7)
input_width = 8
samples_i_in = self._generate_samples(num_samples, input_width, f_width=7)
samples_q_in = self._generate_samples(num_samples, input_width, f_width=7)
# Compute the expected result
filtered_i_np = np.convolve(samples_i_in, taps)[1::2].tolist()
filtered_q_np = np.convolve(samples_q_in, taps)[1::2].tolist()
# Compute the expected result
filtered_i_np = np.convolve(samples_i_in, taps)[1::2].tolist()
filtered_q_np = np.convolve(samples_q_in, taps)[1::2].tolist()
# Simulate DUT
dut = HalfBandDecimator(taps, data_shape=fixed.SQ(7), shape_out=fixed.SQ(0,16), always_ready=False)
filtered = self._filter(dut, zip(samples_i_in, samples_q_in), len(samples_i_in) // 2, num_channels=2)
filtered_i = [ x[0] for x in filtered ]
filtered_q = [ x[1] for x in filtered ]
# Simulate DUT
dut = HalfBandDecimator(**scenario["dut_options"])
filtered = self._filter(dut, zip(samples_i_in, samples_q_in), len(samples_i_in) // 2, num_channels=2, **scenario["sim_opts"])
filtered_i = [ x[0] for x in filtered ]
filtered_q = [ x[1] for x in filtered ]
self.assertListEqual(filtered_i_np[:len(filtered_i)], filtered_i)
self.assertListEqual(filtered_q_np[:len(filtered_q)], filtered_q)
self.assertListEqual(filtered_i_np[:len(filtered_i)], filtered_i)
self.assertListEqual(filtered_q_np[:len(filtered_q)], filtered_q)
class TestHalfBandInterpolator(_TestFilter):
def test_filter(self):
taps = [-1, 0, 9, 16, 9, 0, -1]
taps = [ tap / 32 for tap in taps ]
num_samples = 1024
input_width = 8
input_samples = self._generate_samples(num_samples, input_width, f_width=7)
# Compute the expected result
input_samples_pad = np.zeros(2*len(input_samples))
input_samples_pad[0::2] = 2*input_samples # pad with zeros, adjust gain
filtered_np = np.convolve(input_samples_pad, taps).tolist()
common_dut_options = dict(
data_shape=fixed.SQ(7),
shape_out=fixed.SQ(1,16),
)
# Simulate DUT
dut = HalfBandInterpolator(taps, data_shape=fixed.SQ(0, 7), shape_out=fixed.SQ(0,16), always_ready=False)
filtered = self._filter(dut, input_samples, len(input_samples) * 2)
taps0 = (np.array([-1, 0, 9, 16, 9, 0, -1]) / 32).tolist()
taps1 = (np.array([-2, 0, 7, 0, -18, 0, 41, 0, -92, 0, 320, 512, 320, 0, -92, 0, 41, 0, -18, 0, 7, 0, -2]) / 1024).tolist()
self.assertListEqual(filtered_np[:len(filtered)], filtered)
inputs = {
"test_filter_with_backpressure": {
"num_samples": 1024,
"dut_options": dict(**common_dut_options, always_ready=False, num_channels=2, taps=taps1),
"sim_opts": dict(empty_cycles=0, empty_ready_cycles=0),
},
"test_filter_with_backpressure_and_empty_cycles": {
"num_samples": 1024,
"dut_options": dict(**common_dut_options, num_channels=2, always_ready=False, taps=taps0),
"sim_opts": dict(empty_ready_cycles=7, empty_cycles=3),
},
"test_filter_with_backpressure_taps1": {
"num_samples": 1024,
"dut_options": dict(**common_dut_options, num_channels=2, always_ready=False, taps=taps1),
"sim_opts": dict(empty_ready_cycles=7, empty_cycles=0),
},
"test_filter_no_backpressure_and_empty_cycles_taps1": {
"num_samples": 1024,
"dut_options": dict(**common_dut_options, num_channels=2, always_ready=True, taps=taps0),
"sim_opts": dict(empty_cycles=8),
},
"test_filter_no_backpressure": {
"num_samples": 1024,
"dut_options": dict(**common_dut_options, num_channels=2, always_ready=True, taps=taps1),
"sim_opts": dict(empty_cycles=16),
},
}
for name, scenario in inputs.items():
with self.subTest(name):
taps = scenario["dut_options"]["taps"]
num_samples = scenario["num_samples"]
input_width = 8
samples_i_in = self._generate_samples(num_samples, input_width, f_width=7)
samples_q_in = self._generate_samples(num_samples, input_width, f_width=7)
# Compute the expected result
input_samples_pad = np.zeros(2*len(samples_i_in))
input_samples_pad[0::2] = 2*samples_i_in # pad with zeros, adjust gain
filtered_i_np = np.convolve(input_samples_pad, taps).tolist()
input_samples_pad = np.zeros(2*len(samples_q_in))
input_samples_pad[0::2] = 2*samples_q_in # pad with zeros, adjust gain
filtered_q_np = np.convolve(input_samples_pad, taps).tolist()
# Simulate DUT
dut = HalfBandInterpolator(**scenario["dut_options"])
filtered = self._filter(dut, zip(samples_i_in, samples_q_in), len(samples_i_in) * 2, num_channels=2, **scenario["sim_opts"])
filtered_i = [ x[0] for x in filtered ]
filtered_q = [ x[1] for x in filtered ]
self.assertListEqual(filtered_i_np[:len(filtered_i)], filtered_i)
self.assertListEqual(filtered_q_np[:len(filtered_q)], filtered_q)
if __name__ == "__main__":
unittest.main()
unittest.main()