通用大规模可复用的 systolic 矩阵乘法器

通用大规模可复用的 systolic 矩阵乘法器本文探讨了对 Systolic 阵列进行改进 特别是在矩阵计算中 通过控制数据流和利用流水线乘法器实现高效运算

大家好,欢迎来到IT知识分享网。

通用大规模可复用的 systolic 矩阵乘法器
systolic 矩阵乘法设计

在此基础上做了一点小扩充:

通用大规模可复用的 systolic 矩阵乘法器
通用 systolic  乘法器

模块设计:

通用大规模可复用的 systolic 矩阵乘法器
通用乘法器模块设计

在此基础上思考输入与输出的格式对齐方案:我使用了一种简单的设计,在计算完成时刻对整体的数据向下或右移出,这样不用添加额外的寄存器或读取转换电路,只需要将向右侧流动的数据位宽改为结果位宽即可,这样可以实现阵列的复用。

2.根据外部数据输出格式:修改输入输出的矩阵行列元素顺序:顺序输入–顺序输出、逆序输入–逆序输出

以下是输入时刻和输出时刻的数据流,本设计使用自定义无符号流水线乘法器

通用大规模可复用的 systolic 矩阵乘法器
输入
通用大规模可复用的 systolic 矩阵乘法器
输出
通用大规模可复用的 systolic 矩阵乘法器
最小单元设计

module pulse_arrays_pe #( parameter WIDTH_left = 8, parameter WIDTH_up = 8, parameter WIDTH_out = 8 )( input wire clk, input wire rst, input wire [1:0]mode, // mode=0 tpu_computer mode=1,2 shift out_data input wire [WIDTH_out-1:0] left, input wire [WIDTH_up-1:0] up, output reg [WIDTH_out-1:0] right, output reg [WIDTH_up-1:0] down, output reg [WIDTH_out-1:0] out_data ); //reg [WIDTH_out-1:0] out_data; wire [WIDTH_out-1:0] temp_data; always@(posedge clk)begin if(!rst)begin right <= 0; down <= 0; out_data <= 0; end else begin //computer if(mode==0)begin right<= left;//{ 
  {(WIDTH_out-WIDTH_left){1'd0}},left} down <= up; out_data <= temp_data+out_data; end //shift out_data else if(mode==2)begin right <= out_data; out_data <= 0; end else right <= left; end end //multiply module instantiation FIX_unsigned_MUL #(.WIDTH_multiplicand(WIDTH_left), .WIDTH_multiplier(WIDTH_up)) uut ( .clk (clk), .rst (rst), .valid (1'd1), .multiplicand (left[WIDTH_left-1:0]), .multiplier (up), .ready (), .product (temp_data) ); endmodule

通用阵列:

module pulse_arrays #( parameter WIDTH_left = 8, parameter WIDTH_up = 8, parameter WIDTH_out = 8, parameter Mritx_M = 3,//output row parameter Mritx_N = 3,// parameter Mritx_L = 3,//output col parameter Mritx_LOG2_size = 10 //counter's width )( input clk, input rst, input wire valid_left, input wire valid_up, input wire [Mritx_M*WIDTH_left-1:0] left, input wire [Mritx_L*WIDTH_up-1:0] up, output reg ready, //ready for input output wire [WIDTH_out*Mritx_M-1:0] product ); localparam idl = 4'd0; localparam state_in = 4'd1; localparam state_out= 4'd2; // state register reg [3:0] state; wire [WIDTH_out-1:0] left_temp [Mritx_M*Mritx_L-1:0]; //x_unite_wire wire [WIDTH_up-1:0] up_temp [Mritx_M*Mritx_L-1:0]; //y_unite_wire //enable signal // reg [Mritx_M-1:0] left_shift_en; // reg [Mritx_L-1:0] up_shift_en; reg star=0,export=0,finish=0;//flag reg [Mritx_LOG2_size-1:0] cnt_flow1; reg [Mritx_LOG2_size-1:0] cnt_flow2; reg [2*Mritx_L-1:0] mode_control=0;// 行同步控制 初始赋值0 大量变1易串扰不稳定 //output wire [WIDTH_out-1:0] out_data[Mritx_M*Mritx_L-1:0]; assign left_temp[0] = valid_left?{ 
  {(WIDTH_out-WIDTH_left){1'd0}},left[WIDTH_left-1:0]}:0; assign up_temp[0] = valid_up?up[WIDTH_up-1:0]:0; always @(posedge clk ) begin if(!rst)begin state <= idl; star <= 0; export<= 0; ready <= 0; finish<= 0; cnt_flow1 <= 0; cnt_flow2 <= 0; mode_control <= 0; end else begin case (state) idl :begin ready <= 1; star <= 0; export <= 0; finish <= 0; cnt_flow1 <= 0; cnt_flow2 <= 0; mode_control <= 0; if(valid_left&valid_up) state <= state_in; else state <= idl; end state_in :begin ready <= 0; star <= 1; export <= 0; finish <= 0; cnt_flow1 <= cnt_flow1 + 1; cnt_flow2 <= 0; if(cnt_flow1==Mritx_M+Mritx_N+Mritx_L-2+WIDTH_up-1)begin state <= state_out; mode_control <= {Mritx_L{2'd2}}; end else state <= state_in; end state_out:begin ready <= 0; star <= 0; cnt_flow1 <= 0; cnt_flow2 <= cnt_flow2 + 1; if(cnt_flow2==0) mode_control <= {Mritx_L{2'd1}}<<2; else mode_control <= mode_control<<2; if(cnt_flow2==Mritx_L) begin state <= idl; finish<= 1; export<= 0; end else begin state <= state_out; finish<= 0; export<= 1; end end default: begin ready <= 0; star <= 0; export <= 0; finish <= 0; cnt_flow1 <= 0; cnt_flow2 <= 0; mode_control <= 0; state <= idl; end endcase end end generate genvar i,j; for(i=0;i<=Mritx_M-1;i=i+1)begin for(j=0;j<=Mritx_L-1;j=j+1)begin:pulse_arrays_pex if(i==0&&j==0)begin pulse_arrays_pe #( .WIDTH_left (WIDTH_left), .WIDTH_up (WIDTH_up), .WIDTH_out (WIDTH_out) )pulse_arrays_pex( .clk (clk), .rst (rst), .mode (mode_control[(j+1)*2-1:j*2]), .left (left_temp[0]), .up (up_temp[0]), .right (left_temp[i*Mritx_L+j+1]), .down (up_temp[(i+1)*Mritx_L+j]), .out_data (out_data[i*Mritx_L+j]) ); end else if(i==Mritx_M-1&&j==Mritx_L-1)begin pulse_arrays_pe #( .WIDTH_left (WIDTH_left), .WIDTH_up (WIDTH_up), .WIDTH_out (WIDTH_out) )pulse_arrays_pex( .clk (clk), .rst (rst), .mode (mode_control[(j+1)*2-1:j*2]), .left (left_temp[i*Mritx_L+j]), .up (up_temp[i*Mritx_L+j]), .right (product[WIDTH_out*(i+1)-1:WIDTH_out*i]), .down (), .out_data (out_data[i*Mritx_L+j]) ); end else if(i==Mritx_M-1)begin pulse_arrays_pe #( .WIDTH_left (WIDTH_left), .WIDTH_up (WIDTH_up), .WIDTH_out (WIDTH_out) )pulse_arrays_pex( .clk (clk), .rst (rst), .mode (mode_control[(j+1)*2-1:j*2]), .left (left_temp[i*Mritx_L+j]), .up (up_temp[i*Mritx_L+j]), .right (left_temp[i*Mritx_L+j+1]), .down (), .out_data (out_data[i*Mritx_L+j]) ); end else if(j==Mritx_L-1)begin pulse_arrays_pe #( .WIDTH_left (WIDTH_left), .WIDTH_up (WIDTH_up), .WIDTH_out (WIDTH_out) )pulse_arrays_pex( .clk (clk), .rst (rst), .mode (mode_control[(j+1)*2-1:j*2]), .left (left_temp[i*Mritx_L+j]), .up (up_temp[i*Mritx_L+j]), .right (product[WIDTH_out*(i+1)-1:WIDTH_out*i]), .down (up_temp[(i+1)*Mritx_L+j]), .out_data (out_data[i*Mritx_L+j]) ); end else begin pulse_arrays_pe #( .WIDTH_left (WIDTH_left), .WIDTH_up (WIDTH_up), .WIDTH_out (WIDTH_out) )pulse_arrays_pex( .clk (clk), .rst (rst), .mode (mode_control[(j+1)*2-1:j*2]), .left (left_temp[i*Mritx_L+j]), .up (up_temp[i*Mritx_L+j]), .right (left_temp[i*Mritx_L+j+1]), .down (up_temp[(i+1)*Mritx_L+j]), .out_data (out_data[i*Mritx_L+j]) ); end end end endgenerate generate genvar m,n; for(m=1;m<Mritx_M;m=m+1)begin:shift_register_left shift_register #( .WIDTH_in(WIDTH_left), .WIDTH_out(WIDTH_out), .DEEP(m), .PTR_SIZE(Mritx_LOG2_size) )shift_register_left( .clk(clk), .rst(rst), .shift_en(valid_left), .shift_in(left[(m+1)*WIDTH_left-1:(m)*WIDTH_left]), .shift_out(left_temp[m*Mritx_L]) ); end for(n=1;n<Mritx_L;n=n+1)begin:shift_register_up shift_register #( .WIDTH_in(WIDTH_up), .WIDTH_out(WIDTH_up), .DEEP(n), .PTR_SIZE(Mritx_LOG2_size) )shift_register_up( .clk(clk), .rst(rst), .shift_en(valid_up), .shift_in(up[(n+1)*WIDTH_up-1:n*WIDTH_up]), .shift_out(up_temp[n]) ); end endgenerate endmodule

要想在这个架构上实现真正的流水线矩阵乘法器,需要按照第一张图中的单元计算完成梯度图来将数据读出,但我目前并没有找到在不消耗大量硬件资源的情况下可通用的方法,后续找到的话会继续更新。

仿真:

通用大规模可复用的 systolic 矩阵乘法器通用大规模可复用的 systolic 矩阵乘法器

通用大规模可复用的 systolic 矩阵乘法器

通用大规模可复用的 systolic 矩阵乘法器

通用大规模可复用的 systolic 矩阵乘法器

通用大规模可复用的 systolic 矩阵乘法器

通用大规模可复用的 systolic 矩阵乘法器

本项目开源所有代码已上传至gihub仓库:Debug-xmh/Systiolic-Matrix-multiplier: Universal matrix multiplier, an improvement on the design of Systiolic Matrix multiplier in verilog. (github.com)

免责声明:本站所有文章内容,图片,视频等均是来源于用户投稿和互联网及文摘转载整编而成,不代表本站观点,不承担相关法律责任。其著作权各归其原作者或其出版社所有。如发现本站有涉嫌抄袭侵权/违法违规的内容,侵犯到您的权益,请在线联系站长,一经查实,本站将立刻删除。 本文来自网络,若有侵权,请联系删除,如若转载,请注明出处:https://haidsoft.com/149328.html

(0)
上一篇 2025-03-24 17:45
下一篇 2025-03-24 18:00

相关推荐

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注微信