diff --git a/README.md b/README.md index b976f82..944a232 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ the community. This is the place to do it. | dclib | internal | Guy Hutchison | Utility components for DecoupledIO interfaces | | ecc | internal | Guy Hutchison | Hamming Error-Correcting code modules | | iir | internal | Kevin Joly | Infinite Impulse Response filter module | +| fir | internal | Kevin Joly | Finite Impulse Response filter module | ### Using ip-contributions diff --git a/src/main/scala/chisel/lib/firfilter/FIRFilter.scala b/src/main/scala/chisel/lib/firfilter/FIRFilter.scala new file mode 100644 index 0000000..af037b8 --- /dev/null +++ b/src/main/scala/chisel/lib/firfilter/FIRFilter.scala @@ -0,0 +1,148 @@ +/* + * + * A fixed point FIR filter module. + * + * Author: Kevin Joly (kevin.joly@armadeus.com) + * + */ + +package chisel.lib.firfilter + +import chisel3._ +import chisel3.experimental.ChiselEnum +import chisel3.util._ + +/* + * FIR filter module + * + * Apply filter on input samples passed by ready/valid handshake. Coefficients + * are to be set prior to push any input sample. + * + * All the computations are done in fixed point. Output width is inputWidth + + * coefWidth + log2Ceil(coefNum). + * + */ +class FIRFilter( + inputWidth: Int, + coefWidth: Int, + coefDecimalWidth: Int, + coefNum: Int) + extends Module { + + val outputWidth = inputWidth + coefWidth + log2Ceil(coefNum) + + val io = IO(new Bundle { + /* + * Input samples + */ + val input = Flipped(Decoupled(SInt(inputWidth.W))) + /* + * Filter's coefficients b[0], b[1], ... + */ + val coef = Input(Vec(coefNum, SInt(coefWidth.W))) + /* + * Filtered samples. Fixed point format is: + * (inputWidth+coefWidth).coefDecimalWidth + * Thus, output should be right shifted to the right of 'coefDecimalWidth' bits. + */ + val output = Decoupled(SInt(outputWidth.W)) + }) + + assert(coefWidth >= coefDecimalWidth) + + val coefIdx = RegInit(0.U(coefNum.W)) + + object FIRFilterState extends ChiselEnum { + val Idle, Compute, Valid, LeftOver = Value + } + + val state = RegInit(FIRFilterState.Idle) + + switch(state) { + is(FIRFilterState.Idle) { + when(io.input.valid) { + state := FIRFilterState.Compute + } + } + is(FIRFilterState.Compute) { + when(coefIdx === (coefNum - 1).U) { + state := FIRFilterState.LeftOver + } + } + is(FIRFilterState.LeftOver) { + state := FIRFilterState.Valid + } + is(FIRFilterState.Valid) { + when(io.output.ready) { + state := FIRFilterState.Idle + } + } + } + + when((state === FIRFilterState.Idle) && io.input.valid) { + coefIdx := 1.U + }.elsewhen(state === FIRFilterState.Compute) { + when(coefIdx === (coefNum - 1).U) { + coefIdx := 0.U + }.otherwise { + coefIdx := coefIdx + 1.U + } + }.otherwise { + coefIdx := 0.U + } + + val inputReg = RegInit(0.S(inputWidth.W)) + val inputMem = Mem(coefNum - 1, SInt(inputWidth.W)) + val inputMemAddr = RegInit(0.U(math.max(log2Ceil(coefNum - 1), 1).W)) + val inputMemOut = Wire(SInt(inputWidth.W)) + val inputRdWr = inputMem(inputMemAddr) + + inputMemOut := DontCare + + when(state === FIRFilterState.LeftOver) { + inputRdWr := inputReg + }.elsewhen((state === FIRFilterState.Idle) && io.input.valid) { + inputReg := io.input.bits // Delayed write + inputMemOut := inputRdWr + }.otherwise { + inputMemOut := inputRdWr + } + + when((state === FIRFilterState.Compute) && (coefIdx < (coefNum - 1).U)) { + when(inputMemAddr === (coefNum - 2).U) { + inputMemAddr := 0.U + }.otherwise { + inputMemAddr := inputMemAddr + 1.U + } + } + + val inputSum = RegInit(0.S(outputWidth.W)) + + val multNumOut = Wire(SInt((inputWidth + coefWidth).W)) + val multNumOutReg = RegInit(0.S((inputWidth + coefWidth).W)) + val multNumIn = Wire(SInt(inputWidth.W)) + + when((state === FIRFilterState.Idle) && io.input.valid) { + multNumOutReg := multNumOut + inputSum := 0.S + }.elsewhen(state === FIRFilterState.Compute) { + when(coefIdx < coefNum.U) { + multNumOutReg := multNumOut + inputSum := inputSum +& multNumOutReg + } + }.elsewhen(state === FIRFilterState.LeftOver) { + inputSum := inputSum +& multNumOutReg + } + + when(state === FIRFilterState.Idle) { + multNumIn := io.input.bits + }.otherwise { + multNumIn := inputMemOut + } + + multNumOut := multNumIn * io.coef(coefIdx) + + io.input.ready := state === FIRFilterState.Idle + io.output.valid := state === FIRFilterState.Valid + io.output.bits := inputSum +} diff --git a/src/main/scala/chisel/lib/firfilter/README.md b/src/main/scala/chisel/lib/firfilter/README.md new file mode 100644 index 0000000..f9f7290 --- /dev/null +++ b/src/main/scala/chisel/lib/firfilter/README.md @@ -0,0 +1,7 @@ +# FIR Filter + +Simple fixed point FIR filter with pipelined computation. + +Tests are run as follow: +```sbt "testOnly chisel.lib.firfilter.SimpleFIRFilterTest -- -DwriteVcd=1"``` +```sbt "testOnly chisel.lib.firfilter.RandomSignalTest -- -DwriteVcd=1"``` diff --git a/src/main/scala/chisel/lib/iirfilter/iirfilter.scala b/src/main/scala/chisel/lib/iirfilter/IIRFilter.scala similarity index 100% rename from src/main/scala/chisel/lib/iirfilter/iirfilter.scala rename to src/main/scala/chisel/lib/iirfilter/IIRFilter.scala diff --git a/src/test/scala/chisel/lib/firfilter/RandomSignalTest.scala b/src/test/scala/chisel/lib/firfilter/RandomSignalTest.scala new file mode 100644 index 0000000..19c5751 --- /dev/null +++ b/src/test/scala/chisel/lib/firfilter/RandomSignalTest.scala @@ -0,0 +1,132 @@ +/* + * Filter a random signal using FIRFilter module and compare with the expected output. + * + * See README.md for license details. + */ + +package chisel.lib.firfilter + +import chisel3._ +import chisel3.experimental.VecLiterals._ +import chisel3.util.log2Ceil +import chiseltest._ + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +import scala.util.Random + +trait FIRFilterBehavior { + + this: AnyFlatSpec with ChiselScalatestTester => + + def testFilter( + inputWidth: Int, + inputDecimalWidth: Int, + coefWidth: Int, + coefDecimalWidth: Int, + coefs: Seq[Int], + inputData: Seq[Int], + expectedOutput: Seq[Double], + precision: Double + ): Unit = { + + it should "work" in { + test( + new FIRFilter( + inputWidth = inputWidth, + coefWidth = coefWidth, + coefDecimalWidth = coefDecimalWidth, + coefNum = coefs.length + ) + ) { dut => + dut.io.coef.poke(Vec.Lit(coefs.map(_.S(coefWidth.W)): _*)) + + dut.io.output.ready.poke(true.B) + + for ((d, e) <- (inputData.zip(expectedOutput))) { + + dut.io.input.ready.expect(true.B) + + // Push input sample + dut.io.input.bits.poke(d.S(inputWidth.W)) + dut.io.input.valid.poke(true.B) + + dut.clock.step(1) + + dut.io.input.valid.poke(false.B) + + for (i <- 0 until coefs.length) { + dut.io.output.valid.expect(false.B) + dut.io.input.ready.expect(false.B) + dut.clock.step(1) + } + + // Check output + val outputDecimalWidth = coefDecimalWidth + inputDecimalWidth + val output = dut.io.output.bits.peek().litValue.toFloat / math.pow(2, outputDecimalWidth).toFloat + val upperBound = e + precision + val lowerBound = e - precision + + assert(output < upperBound) + assert(output > lowerBound) + + dut.io.output.valid.expect(true.B) + + dut.clock.step(1) + } + } + } + } +} + +class RandomSignalTest extends AnyFlatSpec with FIRFilterBehavior with ChiselScalatestTester with Matchers { + + def computeExpectedOutput(coefs: Seq[Double], inputData: Seq[Double]): Seq[Double] = { + return for (i <- 0 until inputData.length) yield { + val inputSum = (for (j <- i until math.max(i - coefs.length, -1) by -1) yield { + inputData(j) * coefs(i - j) + }).reduce(_ + _) + + inputSum + } + } + + behavior.of("FIRFilter") + + Random.setSeed(11340702) + + // 9 taps Kaiser high-pass filter 50Hz (sampling freq: 44.1kHz) + val coefs = Seq(-0.00227242, -0.00227255, -0.00227265, -0.00227271, 0.99999962, -0.00227271, -0.00227265, -0.00227255, + -0.00227242) + + // Setup data width + val inputWidth = 16 + val inputDecimalWidth = 12 + + val coefWidth = 32 + val coefDecimalWidth = 28 + + // Generate random input data [-1., 1.] + val inputData = Seq.fill(100)(-1.0 + Random.nextDouble * 2.0) + + // Compute expected outputs + val expectedOutput = computeExpectedOutput(coefs, inputData) + + // Floating point to fixed point data + val coefsInt = for (n <- coefs) yield { (n * math.pow(2, coefDecimalWidth)).toInt } + val inputDataInt = for (x <- inputData) yield (x * math.pow(2, inputDecimalWidth)).toInt + + (it should behave).like( + testFilter( + inputWidth, + inputDecimalWidth, + coefWidth, + coefDecimalWidth, + coefsInt, + inputDataInt, + expectedOutput, + 0.0005 + ) + ) +} diff --git a/src/test/scala/chisel/lib/firfilter/SimpleFIRFilterTest.scala b/src/test/scala/chisel/lib/firfilter/SimpleFIRFilterTest.scala new file mode 100644 index 0000000..15760bd --- /dev/null +++ b/src/test/scala/chisel/lib/firfilter/SimpleFIRFilterTest.scala @@ -0,0 +1,165 @@ +/* + * A very simple test collection for FIRFilter module. + * + * See README.md for license details. + */ + +package chisel.lib.firfilter + +import chisel3._ +import chisel3.experimental.VecLiterals._ +import chisel3.util.log2Ceil +import chiseltest._ + +import org.scalatest.flatspec.AnyFlatSpec + +class FIRFilterCoefTest extends AnyFlatSpec with ChiselScalatestTester { + "FIRFilter coef" should "work" in { + + val inputWidth = 4 + val coefWidth = 3 + val coefDecimalWidth = 0 + val coefs = Seq(2, 1, 0, 3) + + test( + new FIRFilter( + inputWidth = inputWidth, + coefWidth = coefWidth, + coefDecimalWidth = coefDecimalWidth, + coefNum = coefs.length + ) + ) { dut => + dut.io.coef.poke(Vec.Lit(coefs.map(_.S(coefWidth.W)): _*)) + + dut.io.output.ready.poke(true.B) + + // Sample 1: Write 1. on input port + dut.io.input.bits.poke(1.S) + dut.io.input.valid.poke(true.B) + dut.io.input.ready.expect(true.B) + dut.io.output.valid.expect(false.B) + dut.clock.step(1) + dut.io.input.valid.poke(false.B) + dut.io.input.ready.expect(false.B) + + for (i <- 0 until coefs.length) { + dut.io.output.valid.expect(false.B) + dut.io.input.ready.expect(false.B) + dut.clock.step(1) + } + + dut.io.output.bits.expect(2.S) + dut.io.output.valid.expect(true.B) + + dut.clock.step(1) + + // Sample 2: Write 1. on input port + dut.io.input.bits.poke(1.S) + dut.io.input.valid.poke(true.B) + dut.io.input.ready.expect(true.B) + dut.io.output.valid.expect(false.B) + dut.clock.step(1) + dut.io.input.valid.poke(false.B) + + for (i <- 0 until coefs.length) { + dut.io.output.valid.expect(false.B) + dut.io.input.ready.expect(false.B) + dut.clock.step(1) + } + + dut.io.output.bits.expect(3.S) + dut.io.output.valid.expect(true.B) + + dut.clock.step(1) + + // Sample 3: Write 0. on input port + dut.io.input.bits.poke(0.S) + dut.io.input.valid.poke(true.B) + dut.io.input.ready.expect(true.B) + dut.io.output.valid.expect(false.B) + dut.clock.step(1) + dut.io.input.valid.poke(false.B) + + for (i <- 0 until coefs.length) { + dut.io.output.valid.expect(false.B) + dut.io.input.ready.expect(false.B) + dut.clock.step(1) + } + + dut.io.output.bits.expect(1.S) + dut.io.output.valid.expect(true.B) + + dut.clock.step(1) + + // Sample 4: Write 0. on input port + dut.io.input.bits.poke(0.S) + dut.io.input.valid.poke(true.B) + dut.io.input.ready.expect(true.B) + dut.io.output.valid.expect(false.B) + dut.clock.step(1) + dut.io.input.valid.poke(false.B) + + for (i <- 0 until coefs.length) { + dut.io.output.valid.expect(false.B) + dut.io.input.ready.expect(false.B) + dut.clock.step(1) + } + + dut.io.output.bits.expect(3.S) + dut.io.output.valid.expect(true.B) + } + } +} + +class FIRFilterReadyTest extends AnyFlatSpec with ChiselScalatestTester { + "FIRFilter" should "work" in { + + val inputWidth = 4 + val coefWidth = 3 + val coefDecimalWidth = 0 + val coefs = Seq(1, 2, 0) + + test( + new FIRFilter( + inputWidth = inputWidth, + coefWidth = coefWidth, + coefDecimalWidth = coefDecimalWidth, + coefNum = coefs.length + ) + ) { dut => + dut.io.coef.poke(Vec.Lit(coefs.map(_.S(coefWidth.W)): _*)) + + dut.io.output.ready.poke(false.B) + + // Sample 1: Write 1. on input port + dut.io.input.bits.poke(1.S) + dut.io.input.valid.poke(true.B) + dut.io.input.ready.expect(true.B) + dut.io.output.valid.expect(false.B) + dut.clock.step(1) + dut.io.input.valid.poke(false.B) + dut.io.input.ready.expect(false.B) + + for (i <- 0 until coefs.length) { + dut.io.output.valid.expect(false.B) + dut.io.input.ready.expect(false.B) + dut.clock.step(1) + } + + val extraClockCycles = 10 + for (i <- 0 until extraClockCycles) { + dut.io.output.valid.expect(true.B) + dut.io.input.ready.expect(false.B) + dut.clock.step(1) + } + + dut.io.output.ready.poke(true.B) + + dut.clock.step(1) + + dut.io.output.bits.expect(1.S) + dut.io.output.valid.expect(false.B) + dut.io.input.ready.expect(true.B) + } + } +}