Compute Core Modules

RTL source on GitHub

SystemVerilog sources documented on this page: GEMM_systolic_top.sv, GEMV_top.sv, CVO_top.sv and GEMM_dsp_unit.sv.

1. Matrix Core — Systolic Top

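GEMM_systolic_top.sv wraps the 32 × 32 systolic array. The PCIN/PCOUT partial-sum cascade is broken at row 16, so each column is split into two 16-row sub-chains that are joined through the fabric (C-port). The module receives INT4 weight tiles from HP0/HP1 and activation rows from the L2 cache. It streams accumulated results, together with a matching delayed e_max, to the post-processor.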

Listing 10 hw/rtl/MAT_CORE/GEMM_systolic_top.sv
`timescale 1ns / 1ps

`include "GLOBAL_CONST.svh"
`include "GEMM_Array.svh"

// ===| GEMM_systolic_top |=======================================================
// Target : Kria KV260 @ 400 MHz
//
// Top-level wrapper of the matrix core (Architecture V2). Pipeline, in order:
//   1. GEMM_weight_dispatcher       - registers the HP0 (upper) / HP1 (lower)
//                                     INT4 weight lanes.
//   2. GEMM_fmap_staggered_dispatch - skews the broadcast feature map and the
//                                     3-bit instruction per array column.
//   3. GEMM_systolic_array          - 32 x 32 W4A8 DSP array.
//   4. e_max delay                  - per-column shift register of depth
//                                     `SYSTOLIC_TOTAL_LATENCY so e_max lines up
//                                     with the array's accumulated output.
// ===============================================================================

module GEMM_systolic_top #(
    parameter weight_lane_cnt      = `HP_PORT_CNT,
    parameter weight_width_per_lane = `HP_PORT_SINGLE_WIDTH,
    parameter weight_size          = `INT4_WIDTH,

    // 32 = 128 bit / int4 (4-bit)
    parameter weight_cnt           = `HP_PORT_SINGLE_WIDTH / `INT4_WIDTH,

    parameter array_horizontal     = `ARRAY_SIZE_H,
    parameter array_vertical       = `ARRAY_SIZE_V,

    // v001 fmap was BF16 mantissa on DSP A-port. v002 will replace this
    // with INT8 on B-port; the staggered delay still carries the v001
    // width until the PREPROCESS stage is ported.
    parameter dsp_A_port           = `DEVICE_DSP_A_WIDTH,
    parameter IN_fmap_brodcast     = `FIXED_MANT_WIDTH
)(
    input logic clk,
    input logic rst_n,
    input logic i_clear,

    // Control & Inst
    input logic global_weight_valid,
    input logic [2:0] global_inst,
    input logic global_inst_valid,

    // Feature Map Broadcast (from SRAM Cache)
    input logic [IN_fmap_brodcast-1:0] IN_fmap_broadcast      [0:`ARRAY_SIZE_H-1],
    input logic                        IN_fmap_broadcast_valid,

    // e_max (from Cache for Normalization alignment)
    input logic [`BF16_EXP_WIDTH-1:0]  IN_cached_emax_out[0:`ARRAY_SIZE_H-1],

    // ===| Weight input lanes |===================================================
    //   HP0 -> upper INT4 channel, HP1 -> lower INT4 channel. Both lanes must
    //   present valid data in the same cycle for the W4A8 dual-MAC pipeline.
    //   Arrays are already unpacked to 32 × INT4 upstream (128 bit / 4 bit = 32).
    input  logic [`INT4_WIDTH-1:0] IN_weight_upper      [0:(`HP_PORT_SINGLE_WIDTH/`INT4_WIDTH)-1],
    input  logic                   IN_weight_upper_valid,
    output logic                   IN_weight_upper_ready,
    input  logic [`INT4_WIDTH-1:0] IN_weight_lower      [0:(`HP_PORT_SINGLE_WIDTH/`INT4_WIDTH)-1],
    input  logic                   IN_weight_lower_valid,
    output logic                   IN_weight_lower_ready,

    // Output Results (Raw)
    output logic [`DSP48E2_POUT_SIZE-1:0] raw_res_sum      [0:`ARRAY_SIZE_H-1],
    output logic                          raw_res_sum_valid[0:`ARRAY_SIZE_H-1],

    // Delayed e_max for Normalizers
    output logic [`BF16_EXP_WIDTH-1:0] delayed_emax_32[0:`ARRAY_SIZE_H-1]
);

  // ---------------------------------------------------------------------------
  // Stage 1: weight dispatcher (registers the two INT4 lanes)
  // ---------------------------------------------------------------------------
  logic [weight_size-1:0] disp_w_hi [0:weight_cnt-1];
  logic [weight_size-1:0] disp_w_lo [0:weight_cnt-1];
  logic                   disp_w_valid;
  // NOTE(review): disp_w_valid is produced by the dispatcher but never read in
  // this module. The array's weight-latch strobe comes from global_weight_valid
  // instead. Confirm that the two are cycle-aligned.

  GEMM_weight_dispatcher #(
    .weight_cnt (weight_cnt),
    .weight_size(weight_size)
  ) u_weight_unpacker (
      .clk             (clk),
      .rst_n           (rst_n),

      // HP0 -> upper channel
      .fifo_upper      (IN_weight_upper),
      .fifo_upper_valid(IN_weight_upper_valid),
      .fifo_upper_ready(IN_weight_upper_ready),

      // HP1 -> lower channel
      .fifo_lower      (IN_weight_lower),
      .fifo_lower_valid(IN_weight_lower_valid),
      .fifo_lower_ready(IN_weight_lower_ready),

      // registered outputs toward the array
      .weight_upper    (disp_w_hi),
      .weight_lower    (disp_w_lo),
      .weight_valid    (disp_w_valid)
  );

  // ---------------------------------------------------------------------------
  // Stage 2: per-column skew of feature map and instruction
  // ---------------------------------------------------------------------------
  logic [dsp_A_port-1:0] skew_fmap       [0:`ARRAY_SIZE_H-1];
  logic                  skew_fmap_valid [0:`ARRAY_SIZE_H-1];
  logic [           2:0] skew_inst       [0:`ARRAY_SIZE_H-1];
  logic                  skew_inst_valid [0:`ARRAY_SIZE_H-1];

  // NOTE(review): array_size is fed from array_vertical, while the skewed
  // arrays are sized by `ARRAY_SIZE_H. These only agree because the array is
  // square (32 x 32).
  GEMM_fmap_staggered_dispatch #(
      .array_size    (array_vertical),
      .fmap_width    (IN_fmap_brodcast),
      .fmap_out_width(dsp_A_port)
  ) u_delay_line (
      .clk              (clk),
      .rst_n            (rst_n),

      .fmap_in          (IN_fmap_broadcast),
      .fmap_valid       (IN_fmap_broadcast_valid),
      .global_inst      (global_inst),
      .global_inst_valid(global_inst_valid),

      .row_data         (skew_fmap),
      .row_valid        (skew_fmap_valid),
      .row_inst         (skew_inst),
      .row_inst_valid   (skew_inst_valid)
  );

  // ---------------------------------------------------------------------------
  // Stage 3: systolic array
  // ---------------------------------------------------------------------------
  // TODO(pccx v002 §2.2 follow-up): the PREPROCESS stage should deliver INT8
  // activations. Until then the skewed fmap is still the dsp_A_port-wide v001
  // mantissa, and only its low byte is forwarded as a placeholder.
  logic [7:0] skew_fmap_int8 [0:`ARRAY_SIZE_H-1];

  always_comb begin : act_trunc
    for (int col = 0; col < `ARRAY_SIZE_H; col++) begin
      skew_fmap_int8[col] = skew_fmap[col][7:0];
    end
  end

  // V_out of the array is collected here but not used by this wrapper.
  logic [`DSP48E2_POUT_SIZE-1:0] array_v_out [0:`ARRAY_SIZE_H-1];

  GEMM_systolic_array #(
      .array_horizontal(`ARRAY_SIZE_H),
      .array_vertical  (`ARRAY_SIZE_V),
      .INT4_BITS       (`INT4_WIDTH),
      .INT8_BITS       (8),
      .B_PORT_W        (`DEVICE_DSP_B_WIDTH)
  ) u_compute_core (
      .clk           (clk),
      .rst_n         (rst_n),
      .i_clear       (i_clear),
      .i_weight_valid(global_weight_valid),

      // horizontal: two independent INT4 weight streams
      .H_in_upper    (disp_w_hi),
      .H_in_lower    (disp_w_lo),

      // vertical: INT8 activation placeholder and instruction stream
      .V_in          (skew_fmap_int8),
      .in_valid      (skew_fmap_valid),
      .inst_in       (skew_inst),
      .inst_valid_in (skew_inst_valid),

      // results
      .V_out         (array_v_out),
      .V_ACC_out     (raw_res_sum),
      .V_ACC_valid   (raw_res_sum_valid)
  );

  // ---------------------------------------------------------------------------
  // Stage 4: e_max delay, one shift register of TOTAL_LATENCY taps per column
  // ---------------------------------------------------------------------------
  localparam TOTAL_LATENCY = `SYSTOLIC_TOTAL_LATENCY;

  generate
    for (genvar col = 0; col < `ARRAY_SIZE_H; col++) begin : g_emax_delay
      logic [`BF16_EXP_WIDTH-1:0] taps [0:TOTAL_LATENCY-1];

      // Synchronous, active-low reset clears every tap.
      always_ff @(posedge clk) begin
        if (!rst_n) begin
          for (int t = 0; t < TOTAL_LATENCY; t++) taps[t] <= '0;
        end else begin
          taps[0] <= IN_cached_emax_out[col];
          for (int t = 1; t < TOTAL_LATENCY; t++) taps[t] <= taps[t-1];
        end
      end

      // The last tap is the e_max aligned with this column's array output.
      assign delayed_emax_32[col] = taps[TOTAL_LATENCY-1];
    end
  endgenerate

endmodule

2. Vector Core — GEMV Top

GEMV_top.sv instantiates one shared LUT generator (GEMV_generate_lut) that feeds four parallel GEMV reduction branches (lanes A–D). Each branch has a 32-wide LUT-based MAC and a 5-stage reduction tree (32 → 1). Stage 1 uses 16 DSP48E2 slices; Stages 2–5 are LUT adders. Weights stream from HP2/HP3.

Listing 11 hw/rtl/VEC_CORE/GEMV_top.sv
`timescale 1ns / 1ps

`include "GEMV_Vec_Matrix_MUL.svh"
`include "GLOBAL_CONST.svh"

// ===| GEMV Top |================================================================
// Vector-core wrapper. A single GEMV_generate_lut builds a table from the
// broadcast feature map and its e_max values. That table is shared by four
// GEMV_reduction_branch instances (lanes A, B, C, D), each with its own
// weight stream, result vector and valid flag.
//
// weight size  = 4 bit (INT4), param.weight_width bits per element
// feature map  = fixed-point mantissa (param.fixed_mant_width bits) per
//                broadcast slot, plus one e_max exponent per slot
//                (IN_cached_emax_out, dtype_pkg::Bf16ExpWidth bits)
//
// Parameters A..D choose which IN_activated_lane entry drives each lane.
//
// Every child module is configured from this module's own `param`. Overriding
// `param` at instantiation therefore keeps GEMV_top's port widths and its
// children's port widths in agreement.
// ===============================================================================
module GEMV_top
  import vec_core_pkg::*;
#(
    parameter gemv_cfg_t param = VecCoreDefaultCfg,
    parameter A = 0,
    parameter B = 1,
    parameter C = 2,
    parameter D = 3
) (
    input logic clk,
    input logic rst_n,

    input logic IN_weight_valid_A,
    input logic IN_weight_valid_B,
    input logic IN_weight_valid_C,
    input logic IN_weight_valid_D,

    input logic [param.weight_width - 1:0] IN_weight_A[0:param.weight_cnt -1],
    input logic [param.weight_width - 1:0] IN_weight_B[0:param.weight_cnt -1],
    input logic [param.weight_width - 1:0] IN_weight_C[0:param.weight_cnt -1],
    input logic [param.weight_width - 1:0] IN_weight_D[0:param.weight_cnt -1],

    output logic OUT_weight_ready_A,
    output logic OUT_weight_ready_B,
    output logic OUT_weight_ready_C,
    output logic OUT_weight_ready_D,

    input logic [param.fixed_mant_width-1:0] IN_fmap_broadcast      [0:param.fmap_cache_out_cnt-1],
    input logic                              IN_fmap_broadcast_valid,

    input logic [16:0] IN_num_recur,
    // e_max (from Cache for Normalization alignment)
    input logic [dtype_pkg::Bf16ExpWidth-1:0] IN_cached_emax_out[0:param.fmap_cache_out_cnt-1],

    input logic IN_activated_lane[0:param.num_gemv_pipeline-1],

    // Per-lane batch vector. Width tracks `GEMV_reduction_branch`'s signed
    // fixed-point format (fixed_mant_width + 2 headroom + 1 sign).
    output logic [param.fixed_mant_width+2:0] OUT_final_fmap_A [0:param.gemv_batch-1],
    output logic [param.fixed_mant_width+2:0] OUT_final_fmap_B [0:param.gemv_batch-1],
    output logic [param.fixed_mant_width+2:0] OUT_final_fmap_C [0:param.gemv_batch-1],
    output logic [param.fixed_mant_width+2:0] OUT_final_fmap_D [0:param.gemv_batch-1],

    output logic OUT_result_valid_A,
    output logic OUT_result_valid_B,
    output logic OUT_result_valid_C,
    output logic OUT_result_valid_D
);

  // ===| Weight ready |=========================================================
  // These outputs were previously left undriven. No weight-ready port of
  // GEMV_reduction_branch is connected here, and IN_weight_valid_* is forwarded
  // to each branch unconditionally, so every valid weight beat is already handed
  // to its branch. The ready outputs are tied high to report exactly that.
  // NOTE(review): if GEMV_reduction_branch gains backpressure, wire it here.
  assign OUT_weight_ready_A = 1'b1;
  assign OUT_weight_ready_B = 1'b1;
  assign OUT_weight_ready_C = 1'b1;
  assign OUT_weight_ready_D = 1'b1;

  // ===| Shared LUT |===========================================================
  // One table per broadcast slot, shared by all four reduction branches.
  logic signed [param.fixed_mant_width+2:0] fmap_LUT_wire[0:param.fmap_cache_out_cnt-1][0:param.weight_width-1];

  logic fmap_ready_wire;

  GEMV_generate_lut #(
      .param(param)
  ) u_GEMV_generate_lut (
      .IN_fmap_broadcast(IN_fmap_broadcast),
      .IN_fmap_broadcast_valid(IN_fmap_broadcast_valid),
      .IN_cached_emax_out(IN_cached_emax_out),

      .OUT_fmap_LUT  (fmap_LUT_wire),
      .OUT_fmap_ready(fmap_ready_wire)
  );

  // ===| Reduction branches (lanes A..D) |=====================================
  // Identical except for weight stream, activation-lane index and outputs.

  GEMV_reduction_branch #(
      .param(param)
  ) u_GEMV_reduction_branch_A (
      .clk  (clk),
      .rst_n(rst_n),

      .IN_weight_valid(IN_weight_valid_A),
      .IN_weight(IN_weight_A),

      .fmap_ready(fmap_ready_wire),
      .IN_num_recur(IN_num_recur),  // shape x * y * z

      .IN_activated_lane(IN_activated_lane[A]),
      .IN_fmap_LUT(fmap_LUT_wire),

      .OUT_GEMV_result_vector(OUT_final_fmap_A),
      .OUT_valid(OUT_result_valid_A)
  );

  GEMV_reduction_branch #(
      .param(param)
  ) u_GEMV_reduction_branch_B (
      .clk  (clk),
      .rst_n(rst_n),

      .IN_weight_valid(IN_weight_valid_B),
      .IN_weight(IN_weight_B),

      .fmap_ready(fmap_ready_wire),
      .IN_num_recur(IN_num_recur),  // shape x * y * z

      .IN_activated_lane(IN_activated_lane[B]),
      .IN_fmap_LUT(fmap_LUT_wire),

      .OUT_GEMV_result_vector(OUT_final_fmap_B),
      .OUT_valid(OUT_result_valid_B)
  );

  GEMV_reduction_branch #(
      .param(param)
  ) u_GEMV_reduction_branch_C (
      .clk  (clk),
      .rst_n(rst_n),

      .IN_weight_valid(IN_weight_valid_C),
      .IN_weight(IN_weight_C),

      .fmap_ready(fmap_ready_wire),
      .IN_num_recur(IN_num_recur),  // shape x * y * z

      .IN_activated_lane(IN_activated_lane[C]),
      .IN_fmap_LUT(fmap_LUT_wire),

      .OUT_GEMV_result_vector(OUT_final_fmap_C),
      .OUT_valid(OUT_result_valid_C)
  );

  GEMV_reduction_branch #(
      .param(param)
  ) u_GEMV_reduction_branch_D (
      .clk  (clk),
      .rst_n(rst_n),

      .IN_weight_valid(IN_weight_valid_D),
      .IN_weight(IN_weight_D),

      .fmap_ready(fmap_ready_wire),
      .IN_num_recur(IN_num_recur),  // shape x * y * z

      .IN_activated_lane(IN_activated_lane[D]),
      .IN_fmap_LUT(fmap_LUT_wire),

      .OUT_GEMV_result_vector(OUT_final_fmap_D),
      .OUT_valid(OUT_result_valid_D)
  );

endmodule

See also

GEMV Core

3. CVO / SFU Core

CVO_top.sv puts the CORDIC + LUT hybrid units for non-linear operations behind a single BF16 streaming interface. CVO_sfu_unit handles exp, sqrt, gelu, recip, scale and reduce_sum. CVO_cordic_unit handles sin and cos. Precision is promoted to BF16/FP32 for all computations.

Listing 12 hw/rtl/CVO_CORE/CVO_top.sv
`timescale 1ns / 1ps
`include "GLOBAL_CONST.svh"

import isa_pkg::*;
import bf16_math_pkg::*;

// ===| CVO Top |=================================================================
// Wraps CVO_sfu_unit (EXP/SQRT/GELU/RECIP/SCALE/REDUCE_SUM) and
// CVO_cordic_unit (SIN/COS) behind a unified streaming interface.
//
// Data flow:
//   Host issues OP_CVO via AXI-Lite → Global_Scheduler produces cvo_control_uop_t
//   → CVO_top latches the uop, accepts IN_uop.length BF16 elements from the L2
//     stream, and writes results back via the output stream.
//
// FLAG_SUB_EMAX: subtract IN_e_max from each input before the function.
//   Implements exp(x - e_max) for numerically stable softmax.
// FLAG_ACCM: accumulate output into dst (add OUT_result to prior value).
//   Handled externally by the mem subsystem; CVO_top only signals it via OUT_accm.
//
// FSM states:
//   IDLE    : waiting for a valid uop. A uop with length == 0 goes straight to
//             DONE. Otherwise length - 1 would wrap to 16'hFFFF and RUNNING
//             would wait for 65536 elements.
//   RUNNING : accepting IN_uop.length elements into the selected unit
//   DONE    : raise OUT_done for one cycle, return to IDLE
//
// OUT_done is registered. It is high in the cycle after DONE, when the FSM is
// already back in IDLE. It marks that the last *input* element was accepted;
// results may still be in flight inside the selected unit, whose latency is
// not visible from this module.
// ===============================================================================

module CVO_top (
    input  logic        clk,
    input  logic        rst_n,
    input  logic        i_clear,

    // ===| Dispatch from Global_Scheduler |=====================================
    input  cvo_control_uop_t IN_uop,
    input  logic             IN_uop_valid,
    output logic             OUT_uop_ready,

    // ===| BF16 Input Stream (from L2 via mem_dispatcher) |=====================
    input  logic [15:0]  IN_data,
    input  logic         IN_data_valid,
    output logic         OUT_data_ready,

    // ===| BF16 Output Stream (to L2 via mem_dispatcher) |=====================
    output logic [15:0]  OUT_result,
    output logic         OUT_result_valid,
    input  logic         IN_result_ready,

    // ===| e_max for FLAG_SUB_EMAX |============================================
    // Passed in as BF16; CVO subtracts this from each element before the function.
    input  logic [15:0]  IN_e_max,

    // ===| Status |=============================================================
    output logic         OUT_busy,
    output logic         OUT_done,
    output logic         OUT_accm   // mirrors IN_uop.flags.accm to mem subsystem
);

  // ===| FSM |===================================================================
  typedef enum logic [1:0] {
    ST_IDLE    = 2'b00,
    ST_RUNNING = 2'b01,
    ST_DONE    = 2'b10
  } cvo_state_e;

  cvo_state_e state;

  // ===| Latched UOP |===========================================================
  cvo_func_e   uop_func;
  cvo_flags_t  uop_flags;
  logic [15:0] uop_length;
  logic [15:0] elem_count;   // elements accepted in the current operation

  // ===| BF16 subtract e_max (combinational) |===================================
  // Implements x - e_max in BF16 as bf16_add(x, -e_max); -e_max is e_max with
  // the sign bit (bit 15) flipped.
  logic [15:0] sub_emax_result_wire;

  always_comb begin : comb_sub_emax
    sub_emax_result_wire = bf16_add(IN_data, {~IN_e_max[15], IN_e_max[14:0]});
  end

  // ===| Input to sub-units (after optional e_max subtraction) |=================
  logic [15:0] data_to_unit_wire;
  logic        data_valid_to_unit_wire;

  always_comb begin
    data_to_unit_wire       = uop_flags.sub_emax ? sub_emax_result_wire : IN_data;
    data_valid_to_unit_wire = (state == ST_RUNNING) && IN_data_valid;
  end

  // ===| Opcode Routing (declared ahead of units that use it as a gating term) |
  logic is_cordic_op_wire;
  always_comb begin
    is_cordic_op_wire = (uop_func == CVO_SIN) || (uop_func == CVO_COS);
  end

  // ===| SFU Instantiation |=====================================================
  logic [15:0] sfu_result;
  logic        sfu_result_valid;
  logic        sfu_ready;

  CVO_sfu_unit u_CVO_sfu_unit (
      .clk             (clk),
      .rst_n           (rst_n),
      .i_clear         (i_clear),

      .IN_func         (uop_func),
      .IN_length       (uop_length),
      .IN_flags        (uop_flags),

      .IN_data         (data_to_unit_wire),
      .IN_valid        (data_valid_to_unit_wire && !is_cordic_op_wire),
      .OUT_data_ready  (sfu_ready),

      .OUT_result      (sfu_result),
      .OUT_result_valid(sfu_result_valid)
  );

  // ===| CORDIC Instantiation |==================================================
  logic [15:0] cordic_sin;
  logic [15:0] cordic_cos;
  logic        cordic_valid;

  CVO_cordic_unit u_CVO_cordic_unit (
      .clk          (clk),
      .rst_n        (rst_n),

      .IN_angle_bf16(data_to_unit_wire),
      .IN_valid     (data_valid_to_unit_wire && is_cordic_op_wire),

      .OUT_sin_bf16 (cordic_sin),
      .OUT_cos_bf16 (cordic_cos),
      .OUT_valid    (cordic_valid)
  );

  // ===| FSM Logic |=============================================================
  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      state      <= ST_IDLE;
      uop_func   <= CVO_EXP;
      uop_flags  <= '0;
      uop_length <= 16'd0;
      elem_count <= 16'd0;
      OUT_done   <= 1'b0;
    end else begin
      OUT_done <= 1'b0;

      case (state)
        // ===| IDLE: wait for dispatch |===
        ST_IDLE: begin
          if (IN_uop_valid) begin
            uop_func   <= IN_uop.cvo_func;
            uop_flags  <= IN_uop.flags;
            uop_length <= IN_uop.length;
            elem_count <= 16'd0;
            // A zero-length uop has nothing to stream. Finish immediately
            // instead of letting (length - 1) wrap in RUNNING.
            state      <= (IN_uop.length == '0) ? ST_DONE : ST_RUNNING;
          end
        end

        // ===| RUNNING: count accepted elements |===
        ST_RUNNING: begin
          if (IN_data_valid && OUT_data_ready) begin
            elem_count <= elem_count + 16'd1;
            if (elem_count == uop_length - 16'd1) begin
              state    <= ST_DONE;
            end
          end
        end

        // ===| DONE: pulse and return |===
        ST_DONE: begin
          OUT_done <= 1'b1;
          state    <= ST_IDLE;
        end

        default: state <= ST_IDLE;
      endcase
    end
  end

  // ===| Output Mux |============================================================
  // CORDIC produces both sin and cos for each input; pick the one requested.
  logic [15:0] result_mux_wire;
  logic        result_valid_mux_wire;

  always_comb begin
    if (is_cordic_op_wire) begin
      result_mux_wire       = (uop_func == CVO_SIN) ? cordic_sin : cordic_cos;
      result_valid_mux_wire = cordic_valid;
    end else begin
      result_mux_wire       = sfu_result;
      result_valid_mux_wire = sfu_result_valid;
    end
  end

  // ===| Output Registers |======================================================
  // NOTE(review): valid is qualified by IN_result_ready, which has two effects:
  //   - a result produced while the consumer is not ready is dropped rather
  //     than held;
  //   - OUT_result_valid depends on IN_result_ready, unlike the usual
  //     valid/ready convention.
  // Neither unit exposes an output-side ready, so a real fix needs backpressure
  // (or a skid buffer) inside the units.
  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      OUT_result       <= 16'd0;
      OUT_result_valid <= 1'b0;
    end else begin
      OUT_result       <= result_mux_wire;
      OUT_result_valid <= result_valid_mux_wire && IN_result_ready;
    end
  end

  // ===| Status & Control |======================================================
  assign OUT_busy      = (state != ST_IDLE);
  assign OUT_uop_ready = (state == ST_IDLE);

  // Input ready only depends on sfu_ready for SFU ops. CVO_cordic_unit has no
  // ready port connected here, and its IN_valid is not gated by sfu_ready, so
  // it is assumed to take one element per cycle. Gating CORDIC ops by
  // sfu_ready would let the CORDIC consume an element that the FSM did not
  // count and the producer did not see acknowledged. The producer would then
  // re-present it, duplicating results.
  assign OUT_data_ready = (state == ST_RUNNING) && (is_cordic_op_wire || sfu_ready);

  assign OUT_accm      = uop_flags.accm;

endmodule

4. DSP48E2 MAC Unit

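GEMM_dsp_unit.sv implements the dual-channel W4A8 MAC using a single DSP48E2 slice. Two stationary INT4 weights are packed into the A port and multiplied against one INT8 activation on the B port, giving two MACs per DSP per cycle. See DSP48E2 W4A8 Bit Packing and Sign Recovery for the bit-packing derivation.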

Listing 13 hw/rtl/MAT_CORE/GEMM_dsp_unit.sv
`timescale 1ns / 1ps

// ===============================================================================
// Module: GEMM_dsp_unit
// Phase : pccx v002 (W4A8, 1 DSP = 2 MAC)
//
// Role
// ----
//   A single PE of the 32 × 32 systolic array (with cascade break @ row 16).
//   Holds two stationary INT4 weights (w_upper / w_lower) and multiplies them
//   in parallel against an INT8 activation streaming vertically through the
//   column via B-port cascade, emitting two MACs per DSP per cycle.
//
// Data-flow summary
// -----------------
//   Weights (2 × INT4, horizontal shift-register along each row):
//     in_H_upper -> latch -> packer.w_upper -> A-port  -> out_H_upper
//     in_H_lower -> latch -> packer.w_lower -> A-port  -> out_H_lower
//
//   Activation (INT8, vertical cascade via BCIN/BCOUT):
//     top row : in_V       -> sign-ext(18) -> B  (B_INPUT = DIRECT)
//     others  : BCIN_in    -> cascade      -> B  (B_INPUT = CASCADE)
//     always  : BCOUT_out  -> next row BCIN
//
//   Partial sum (48-bit accumulator, P-port cascade):
//     top / normal rows : V_result_in via PCIN
//     break row (16)    : V_result_in via C-port (fabric), P_fabric_out
//                         feeds the merger back into the lower half.
//
// Notes
// -----
//   * A_INPUT is always DIRECT — weights are stationary per PE, not cascaded.
//   * BREG = 2 is kept so that activation can be shadow-loaded one cycle
//     ahead of the MAC fire, giving a steady one-cycle `o_valid` pipeline
//     latency. AREG = 1 is enough for weights (they rarely change).
//   * The accumulator drain window is bounded by the packer's guard band
//     (UPPER_SHIFT = 21, 1024 MACs per channel). The GEMM instruction
//     dispatcher is responsible for issuing a flush before that limit.
// ===============================================================================

`include "GLOBAL_CONST.svh"
`include "GEMM_Array.svh"

module GEMM_dsp_unit #(
  parameter IS_TOP_ROW    = 0,  // 1 for the first row: B is taken from in_V (fabric)
  parameter BREAK_CASCADE = 0,  // 1 at row 16 of the 32-row physical array
  parameter INT4_BITS     = `INT4_WIDTH,           // 4
  parameter INT8_BITS     = 8,
  parameter A_PORT_W      = `DEVICE_DSP_A_WIDTH,   // 30
  parameter B_PORT_W      = `DEVICE_DSP_B_WIDTH,   // 18
  parameter P_PORT_W      = `DSP48E2_POUT_SIZE,    // 48
  parameter UPPER_SHIFT   = 21                     // matches GEMM_dsp_packer default
) (
  input  logic clk,
  input  logic rst_n,
  input  logic i_clear,

  input  logic i_valid,         // activation data valid (clock-enables B1/B2 and M)
  input  logic i_weight_valid,  // latch a new pair of weights from in_H_*
  output logic o_valid,         // i_valid delayed by one cycle

  // ===| Horizontal weight shift-registers (two INT4 lanes per row) |============
  //   out_H_* is loaded together with the internal weight latch, so it always
  //   equals the weight pair currently held by this PE.
  input  logic [INT4_BITS-1:0] in_H_upper,
  output logic [INT4_BITS-1:0] out_H_upper,
  input  logic [INT4_BITS-1:0] in_H_lower,
  output logic [INT4_BITS-1:0] out_H_lower,

  // ===| Vertical INT8 activation |==============================================
  //   Top row and the cascade-break row use in_V (fabric, B_INPUT = DIRECT).
  //   All other rows ignore in_V and take BCIN_in (B_INPUT = CASCADE).
  input  logic [INT8_BITS-1:0] in_V,
  input  logic [B_PORT_W-1:0]  BCIN_in,
  output logic [B_PORT_W-1:0]  BCOUT_out,

  // ===| VLIW 3-bit instruction (vertical pipeline) |============================
  //   bit [0] : enable MAC (and P-register clock enable)
  //   bit [2] : flush (starts the flush sequencer)
  input  logic [2:0] instruction_in_V,
  output logic [2:0] instruction_out_V,
  input  logic       inst_valid_in_V,
  output logic       inst_valid_out_V,

  // ===| Partial-sum cascade (PCIN/PCOUT + optional fabric break) |==============
  input  logic [P_PORT_W-1:0] V_result_in,
  output logic [P_PORT_W-1:0] V_result_out,
  output logic [P_PORT_W-1:0] P_fabric_out
);

  // ===| Instruction latch |======================================================
  //   current_inst holds the last valid instruction seen on instruction_in_V and
  //   controls the MAC/pass OPMODE below. Cleared by rst_n or i_clear.
  logic [2:0] current_inst;

  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      current_inst <= 3'b000;
    end else if (inst_valid_in_V) begin
      current_inst <= instruction_in_V;
    end
  end

  // Forward the instruction to the next row with one cycle of delay.
  // Note: unlike current_inst, this register is reset by rst_n only, not i_clear.
  always_ff @(posedge clk) begin
    if (!rst_n) begin
      instruction_out_V <= 3'b000;
      inst_valid_out_V  <= 1'b0;
    end else begin
      instruction_out_V <= instruction_in_V;
      inst_valid_out_V  <= inst_valid_in_V;
    end
  end

  // ===| Flush sequencer |========================================================
  //   A pulse is injected into bit [0] of a 4-bit shift register whenever a valid
  //   instruction with bit [2] set (flush) arrives on instruction_in_V. The pulse
  //   then moves up one bit per clock. Bit [0] is written after the shift, and
  //   the last nonblocking assignment wins.
  //
  //   Only bits [1] and [2] are read (is_flushing). This gives a two-cycle
  //   window that forces OPMODE to P = 0 and enables CEP. Presumably two cycles
  //   are needed because OPMODE is registered inside the DSP (OPMODEREG = 1)
  //   while CEP is not, so the zero OPMODE only takes effect one cycle after
  //   is_flushing rises.
  //
  //   Bits [0] and [3] are not read anywhere in this module. Weights are latched
  //   by i_weight_valid, not by this sequencer.
  logic [3:0] flush_sequence;

  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      flush_sequence <= 4'd0;
    end else begin
      flush_sequence <= {flush_sequence[2:0], 1'b0};
      if (inst_valid_in_V && instruction_in_V[2] == 1'b1) begin
        flush_sequence[0] <= 1'b1;
      end
    end
  end

  logic is_flushing;
  assign is_flushing = flush_sequence[1] | flush_sequence[2];

  // ===| OPMODE / ALUMODE |=======================================================
  //   DSP48E2 OPMODE layout: [8:7] W, [6:4] Z, [3:2] Y, [1:0] X.
  //   X = Y = 01 selects the multiplier output M; W = 00 selects 0.
  //   Z selects the incoming partial sum: 001 = PCIN (cascade) or 011 = C
  //   (fabric, used at the cascade break at row 16).
  //   ALUMODE 0000 = add.
  logic [8:0] dynamic_opmode;
  logic [3:0] dynamic_alumode;
  localparam logic [2:0] Z_MUX = BREAK_CASCADE ? 3'b011 : 3'b001;

  always_comb begin
    if (is_flushing) begin
      dynamic_opmode  = 9'b00_000_00_00;   // P = 0
      dynamic_alumode = 4'b0000;
    end else if (current_inst[0] == 1'b1) begin
      dynamic_opmode  = {2'b00, Z_MUX, 2'b01, 2'b01};  // P = (PCIN or C) + A*B
      dynamic_alumode = 4'b0000;
    end else begin
      // Selects P = (PCIN or C) only. dsp_ce_p is low in this branch, so the P
      // register holds its previous value rather than reloading.
      dynamic_opmode  = {2'b00, Z_MUX, 2'b00, 2'b00};  // P = P_prev (pass)
      dynamic_alumode = 4'b0000;
    end
  end

  // P register updates only while MAC is enabled or during the flush window.
  logic dsp_ce_p;
  assign dsp_ce_p = current_inst[0] | is_flushing;

  // ===| Weight latch + horizontal shift |========================================
  //   On i_weight_valid, both the local weight latch and the shift outputs take
  //   in_H_*.
  //   NOTE(review): the DSP A register (AREG = 1) is clock-enabled by the same
  //   i_weight_valid. The packer is driven from w_*_reg, so on that edge the A
  //   register captures the packing of the *previous* weight pair. If
  //   i_weight_valid is a single-cycle strobe, the DSP multiplies with weights
  //   one load behind w_*_reg / out_H_*. Confirm the array holds
  //   i_weight_valid for an extra cycle, or feed the packer from in_H_*.
  logic [INT4_BITS-1:0] w_upper_reg;
  logic [INT4_BITS-1:0] w_lower_reg;

  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      w_upper_reg <= '0;
      w_lower_reg <= '0;
      out_H_upper <= '0;
      out_H_lower <= '0;
    end else if (i_weight_valid) begin
      w_upper_reg <= in_H_upper;
      w_lower_reg <= in_H_lower;
      out_H_upper <= in_H_upper;
      out_H_lower <= in_H_lower;
    end
  end

  // ===| Bit packing (2 × INT4 weights -> A port, INT8 act -> B port) |==========
  //   b_extended only reaches the multiplier on rows with B_INPUT = DIRECT; on
  //   cascade rows the DSP ignores the B pin and uses BCIN.
  logic signed [A_PORT_W-1:0] a_packed;
  logic signed [B_PORT_W-1:0] b_extended;

  GEMM_dsp_packer #(
    .INT4_BITS   (INT4_BITS),
    .INT8_BITS   (INT8_BITS),
    .A_PORT_W    (A_PORT_W),
    .B_PORT_W    (B_PORT_W),
    .UPPER_SHIFT (UPPER_SHIFT)
  ) u_packer (
    .in_w_lower     (w_lower_reg),
    .in_w_upper     (w_upper_reg),
    .in_act         (in_V),
    .out_a_packed   (a_packed),
    .out_b_extended (b_extended)
  );

  // ===| Output-valid pipeline |==================================================
  //   o_valid is i_valid delayed by exactly one cycle. It does not track the
  //   BREG + MREG + PREG latency to the P port. Presumably it accompanies the
  //   activation as it advances one row per cycle through BCOUT. Confirm this
  //   against the BCASCREG setting (left at its default here).
  logic valid_delay;
  always_ff @(posedge clk) begin
    if (!rst_n) valid_delay <= 1'b0;
    else        valid_delay <= i_valid;
  end
  assign o_valid = valid_delay;

  // ===| DSP48E2 primitive |======================================================
  //   A-port  : packer output (DIRECT, no cascade).
  //   B-port  : direct (top row / break row) or BCIN cascade (other rows).
  //   P-port  : PCIN cascade, or C from fabric when BREAK_CASCADE = 1; the unused
  //             one of the two is tied to zero.
  //   The DSP's internal registers are reset by i_clear only; rst_n does not
  //   reach them.
  logic [P_PORT_W-1:0] p_internal;
  logic [P_PORT_W-1:0] dsp_c_input;
  logic [P_PORT_W-1:0] dsp_pcin_input;

  assign dsp_c_input    = BREAK_CASCADE ? V_result_in : '0;
  assign dsp_pcin_input = BREAK_CASCADE ? '0          : V_result_in;

  // Break row re-injects the activation from fabric (same pattern the v001
  // A-port used for BF16 mantissa). Otherwise BCIN cascade drives B.
  DSP48E2 #(
    .A_INPUT    ("DIRECT"),
    .B_INPUT    ((IS_TOP_ROW || BREAK_CASCADE) ? "DIRECT" : "CASCADE"),
    .AREG       (1),
    .BREG       (2),
    .CREG       (0),
    .MREG       (1),
    .PREG       (1),
    .OPMODEREG  (1),
    .ALUMODEREG (1),
    .USE_MULT   ("MULTIPLY")
  ) DSP_HARD_BLOCK (
    .CLK          (clk),
    .RSTA         (i_clear),
    .RSTB         (i_clear),
    .RSTM         (i_clear),
    .RSTP         (i_clear),
    .RSTCTRL      (i_clear),
    .RSTALLCARRYIN(i_clear),
    .RSTALUMODE   (i_clear),
    .RSTC         (i_clear),

    .CEA1     (i_weight_valid),
    .CEA2     (i_weight_valid),
    .CEB1     (i_valid),
    .CEB2     (i_valid),
    .CEM      (i_valid),
    .CEP      (dsp_ce_p),
    .CECTRL   (1'b1),
    .CEALUMODE(1'b1),
    .CEC      (1'b1),

    .A     (a_packed),
    .ACIN  ('0),
    .ACOUT (),

    .B     (b_extended),
    .BCIN  (BCIN_in),
    .BCOUT (BCOUT_out),

    .C     (dsp_c_input),

    .PCIN  (dsp_pcin_input),
    .PCOUT (V_result_out),

    .OPMODE (dynamic_opmode),
    .ALUMODE(dynamic_alumode),
    .P      (p_internal)
  );

  // Fabric copy of P, used by the merger after the cascade break.
  assign P_fabric_out = p_internal;

endmodule

Last verified against

Commit 773bd82 @ pccxai/pccx-FPGA-NPU-LLM-kv260 (2026-04-21).