38103 lines
1.6 MiB
38103 lines
1.6 MiB
// Copyright 2015 Google Inc. All Rights Reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
//
|
|
// single_thread_gemm.h: programatically generated GEMM library header.
|
|
|
|
#ifndef GEMMLOWP_META_SINGLE_THREAD_GEMM_H_
|
|
#define GEMMLOWP_META_SINGLE_THREAD_GEMM_H_
|
|
|
|
#ifdef GEMMLOWP_NEON_32
|
|
|
|
#include <cassert>
|
|
|
|
namespace gemmlowp {
|
|
namespace meta {
|
|
namespace internal {
|
|
|
|
void zip_1x8_aligned(const std::uint8_t* source, std::int32_t count,
|
|
std::int32_t stride, std::uint8_t* destination,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t additive_offset) {
|
|
asm volatile(
|
|
"vmov.i16 q2, #0\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
// Load Aggregate Store.
|
|
"vld1.8 {d0}, [%[source]:64]!\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vst1.8 {d0}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
|
|
// Aggregator Reduction.
|
|
"vmov.32 d1[0], %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[additive_offset]\n"
|
|
"vpaddl.u16 q2, q2\n"
|
|
"vpadd.u32 d6, d4, d5\n"
|
|
"vpadd.u32 d8, d6, d6\n"
|
|
"vmul.i32 q4, q4, d1[0]\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vst1.32 {d8[0]}, [%[destination]]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
|
|
[destination] "+r"(destination), [source] "+r"(source)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "cc", "memory");
|
|
}
|
|
|
|
void zip_1x8_1_aligned(const std::uint8_t* source, std::int32_t count,
|
|
std::int32_t stride, std::uint8_t* destination,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t additive_offset) {
|
|
asm volatile(
|
|
"sub %[count], %[count], #1\n"
|
|
"vmov.i16 q2, #0\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
// Load Aggregate Store.
|
|
"vld1.8 {d0}, [%[source]:64]!\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vst1.8 {d0}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
|
|
// Leftover Load Aggregate Store.
|
|
"vmov.i8 d0, #0\n"
|
|
"vld1.8 {d0[0]}, [%[source]]\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vst1.8 {d0}, [%[destination]:64]!\n"
|
|
|
|
// Aggregator Reduction.
|
|
"vmov.32 d1[0], %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[additive_offset]\n"
|
|
"vpaddl.u16 q2, q2\n"
|
|
"vpadd.u32 d6, d4, d5\n"
|
|
"vpadd.u32 d8, d6, d6\n"
|
|
"vmul.i32 q4, q4, d1[0]\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vst1.32 {d8[0]}, [%[destination]]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
|
|
[destination] "+r"(destination), [source] "+r"(source)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "cc", "memory");
|
|
}
|
|
|
|
void zip_1x8_2_aligned(const std::uint8_t* source, std::int32_t count,
|
|
std::int32_t stride, std::uint8_t* destination,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t additive_offset) {
|
|
asm volatile(
|
|
"sub %[count], %[count], #2\n"
|
|
"vmov.i16 q2, #0\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
// Load Aggregate Store.
|
|
"vld1.8 {d0}, [%[source]:64]!\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vst1.8 {d0}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
|
|
// Leftover Load Aggregate Store.
|
|
"vmov.i8 d0, #0\n"
|
|
"vld1.16 {d0[0]}, [%[source]]\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vst1.8 {d0}, [%[destination]:64]!\n"
|
|
|
|
// Aggregator Reduction.
|
|
"vmov.32 d1[0], %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[additive_offset]\n"
|
|
"vpaddl.u16 q2, q2\n"
|
|
"vpadd.u32 d6, d4, d5\n"
|
|
"vpadd.u32 d8, d6, d6\n"
|
|
"vmul.i32 q4, q4, d1[0]\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vst1.32 {d8[0]}, [%[destination]]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
|
|
[destination] "+r"(destination), [source] "+r"(source)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "cc", "memory");
|
|
}
|
|
|
|
void zip_1x8_3_aligned(const std::uint8_t* source, std::int32_t count,
|
|
std::int32_t stride, std::uint8_t* destination,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t additive_offset) {
|
|
asm volatile(
|
|
"sub %[count], %[count], #3\n"
|
|
"vmov.i16 q2, #0\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
// Load Aggregate Store.
|
|
"vld1.8 {d0}, [%[source]:64]!\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vst1.8 {d0}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
|
|
// Leftover Load Aggregate Store.
|
|
"vmov.i8 d0, #0\n"
|
|
"vld1.16 {d0[0]}, [%[source]]!\n"
|
|
"vld1.8 {d0[2]}, [%[source]]\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vst1.8 {d0}, [%[destination]:64]!\n"
|
|
|
|
// Aggregator Reduction.
|
|
"vmov.32 d1[0], %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[additive_offset]\n"
|
|
"vpaddl.u16 q2, q2\n"
|
|
"vpadd.u32 d6, d4, d5\n"
|
|
"vpadd.u32 d8, d6, d6\n"
|
|
"vmul.i32 q4, q4, d1[0]\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vst1.32 {d8[0]}, [%[destination]]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
|
|
[destination] "+r"(destination), [source] "+r"(source)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "cc", "memory");
|
|
}
|
|
|
|
void zip_1x8_4_aligned(const std::uint8_t* source, std::int32_t count,
|
|
std::int32_t stride, std::uint8_t* destination,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t additive_offset) {
|
|
asm volatile(
|
|
"sub %[count], %[count], #4\n"
|
|
"vmov.i16 q2, #0\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
// Load Aggregate Store.
|
|
"vld1.8 {d0}, [%[source]:64]!\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vst1.8 {d0}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
|
|
// Leftover Load Aggregate Store.
|
|
"vmov.i8 d0, #0\n"
|
|
"vld1.32 {d0[0]}, [%[source]]\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vst1.8 {d0}, [%[destination]:64]!\n"
|
|
|
|
// Aggregator Reduction.
|
|
"vmov.32 d1[0], %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[additive_offset]\n"
|
|
"vpaddl.u16 q2, q2\n"
|
|
"vpadd.u32 d6, d4, d5\n"
|
|
"vpadd.u32 d8, d6, d6\n"
|
|
"vmul.i32 q4, q4, d1[0]\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vst1.32 {d8[0]}, [%[destination]]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
|
|
[destination] "+r"(destination), [source] "+r"(source)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "cc", "memory");
|
|
}
|
|
|
|
void zip_1x8_5_aligned(const std::uint8_t* source, std::int32_t count,
|
|
std::int32_t stride, std::uint8_t* destination,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t additive_offset) {
|
|
asm volatile(
|
|
"sub %[count], %[count], #5\n"
|
|
"vmov.i16 q2, #0\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
// Load Aggregate Store.
|
|
"vld1.8 {d0}, [%[source]:64]!\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vst1.8 {d0}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
|
|
// Leftover Load Aggregate Store.
|
|
"vmov.i8 d0, #0\n"
|
|
"vld1.32 {d0[0]}, [%[source]]!\n"
|
|
"vld1.8 {d0[4]}, [%[source]]\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vst1.8 {d0}, [%[destination]:64]!\n"
|
|
|
|
// Aggregator Reduction.
|
|
"vmov.32 d1[0], %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[additive_offset]\n"
|
|
"vpaddl.u16 q2, q2\n"
|
|
"vpadd.u32 d6, d4, d5\n"
|
|
"vpadd.u32 d8, d6, d6\n"
|
|
"vmul.i32 q4, q4, d1[0]\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vst1.32 {d8[0]}, [%[destination]]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
|
|
[destination] "+r"(destination), [source] "+r"(source)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "cc", "memory");
|
|
}
|
|
|
|
void zip_1x8_6_aligned(const std::uint8_t* source, std::int32_t count,
|
|
std::int32_t stride, std::uint8_t* destination,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t additive_offset) {
|
|
asm volatile(
|
|
"sub %[count], %[count], #6\n"
|
|
"vmov.i16 q2, #0\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
// Load Aggregate Store.
|
|
"vld1.8 {d0}, [%[source]:64]!\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vst1.8 {d0}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
|
|
// Leftover Load Aggregate Store.
|
|
"vmov.i8 d0, #0\n"
|
|
"vld1.32 {d0[0]}, [%[source]]!\n"
|
|
"vld1.16 {d0[2]}, [%[source]]\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vst1.8 {d0}, [%[destination]:64]!\n"
|
|
|
|
// Aggregator Reduction.
|
|
"vmov.32 d1[0], %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[additive_offset]\n"
|
|
"vpaddl.u16 q2, q2\n"
|
|
"vpadd.u32 d6, d4, d5\n"
|
|
"vpadd.u32 d8, d6, d6\n"
|
|
"vmul.i32 q4, q4, d1[0]\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vst1.32 {d8[0]}, [%[destination]]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
|
|
[destination] "+r"(destination), [source] "+r"(source)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "cc", "memory");
|
|
}
|
|
|
|
void zip_1x8_7_aligned(const std::uint8_t* source, std::int32_t count,
|
|
std::int32_t stride, std::uint8_t* destination,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t additive_offset) {
|
|
asm volatile(
|
|
"sub %[count], %[count], #7\n"
|
|
"vmov.i16 q2, #0\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
// Load Aggregate Store.
|
|
"vld1.8 {d0}, [%[source]:64]!\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vst1.8 {d0}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
|
|
// Leftover Load Aggregate Store.
|
|
"vmov.i8 d0, #0\n"
|
|
"vld1.32 {d0[0]}, [%[source]]!\n"
|
|
"vld1.16 {d0[2]}, [%[source]]!\n"
|
|
"vld1.8 {d0[6]}, [%[source]]\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vst1.8 {d0}, [%[destination]:64]!\n"
|
|
|
|
// Aggregator Reduction.
|
|
"vmov.32 d1[0], %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[additive_offset]\n"
|
|
"vpaddl.u16 q2, q2\n"
|
|
"vpadd.u32 d6, d4, d5\n"
|
|
"vpadd.u32 d8, d6, d6\n"
|
|
"vmul.i32 q4, q4, d1[0]\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vst1.32 {d8[0]}, [%[destination]]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
|
|
[destination] "+r"(destination), [source] "+r"(source)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "cc", "memory");
|
|
}
|
|
|
|
void zip_2x8_aligned(const std::uint8_t* source, std::int32_t count,
|
|
std::int32_t stride, std::uint8_t* destination,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t additive_offset) {
|
|
asm volatile(
|
|
"add r0, %[source], %[stride]\n"
|
|
"vmov.i16 q2, #0\n"
|
|
"vmov.i16 q3, #0\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
// Load Aggregate Store.
|
|
"vld1.8 {d0}, [%[source]:64]!\n"
|
|
"vld1.8 {d1}, [r0:64]!\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vaddw.u8 q3, q3, d1\n"
|
|
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
|
|
// Aggregator Reduction.
|
|
"vmov.32 d2[0], %[multiplicative_offset]\n"
|
|
"vdup.32 q4, %[additive_offset]\n"
|
|
"vpaddl.u16 q2, q2\n"
|
|
"vpaddl.u16 q3, q3\n"
|
|
"vpadd.u32 d3, d4, d5\n"
|
|
"vpadd.u32 d10, d6, d7\n"
|
|
"vpadd.u32 d12, d3, d10\n"
|
|
"vmul.i32 q6, q6, d2[0]\n"
|
|
"vadd.i32 q6, q6, q4\n"
|
|
"vst1.32 {d12}, [%[destination]:64]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
|
|
[destination] "+r"(destination), [source] "+r"(source)
|
|
:
|
|
: "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
|
|
"d12", "d13", "cc", "memory");
|
|
}
|
|
|
|
void zip_2x8_1_aligned(const std::uint8_t* source, std::int32_t count,
|
|
std::int32_t stride, std::uint8_t* destination,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t additive_offset) {
|
|
asm volatile(
|
|
"add r0, %[source], %[stride]\n"
|
|
"sub %[count], %[count], #1\n"
|
|
"vmov.i16 q2, #0\n"
|
|
"vmov.i16 q3, #0\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
// Load Aggregate Store.
|
|
"vld1.8 {d0}, [%[source]:64]!\n"
|
|
"vld1.8 {d1}, [r0:64]!\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vaddw.u8 q3, q3, d1\n"
|
|
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
|
|
// Leftover Load Aggregate Store.
|
|
"vmov.i8 d0, #0\n"
|
|
"vmov.i8 d1, #0\n"
|
|
"vld1.8 {d0[0]}, [%[source]]\n"
|
|
"vld1.8 {d1[0]}, [r0]\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vaddw.u8 q3, q3, d1\n"
|
|
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
|
|
|
|
// Aggregator Reduction.
|
|
"vmov.32 d2[0], %[multiplicative_offset]\n"
|
|
"vdup.32 q4, %[additive_offset]\n"
|
|
"vpaddl.u16 q2, q2\n"
|
|
"vpaddl.u16 q3, q3\n"
|
|
"vpadd.u32 d3, d4, d5\n"
|
|
"vpadd.u32 d10, d6, d7\n"
|
|
"vpadd.u32 d12, d3, d10\n"
|
|
"vmul.i32 q6, q6, d2[0]\n"
|
|
"vadd.i32 q6, q6, q4\n"
|
|
"vst1.32 {d12}, [%[destination]:64]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
|
|
[destination] "+r"(destination), [source] "+r"(source)
|
|
:
|
|
: "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
|
|
"d12", "d13", "cc", "memory");
|
|
}
|
|
|
|
void zip_2x8_2_aligned(const std::uint8_t* source, std::int32_t count,
|
|
std::int32_t stride, std::uint8_t* destination,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t additive_offset) {
|
|
asm volatile(
|
|
"add r0, %[source], %[stride]\n"
|
|
"sub %[count], %[count], #2\n"
|
|
"vmov.i16 q2, #0\n"
|
|
"vmov.i16 q3, #0\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
// Load Aggregate Store.
|
|
"vld1.8 {d0}, [%[source]:64]!\n"
|
|
"vld1.8 {d1}, [r0:64]!\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vaddw.u8 q3, q3, d1\n"
|
|
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
|
|
// Leftover Load Aggregate Store.
|
|
"vmov.i8 d0, #0\n"
|
|
"vmov.i8 d1, #0\n"
|
|
"vld1.16 {d0[0]}, [%[source]]\n"
|
|
"vld1.16 {d1[0]}, [r0]\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vaddw.u8 q3, q3, d1\n"
|
|
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
|
|
|
|
// Aggregator Reduction.
|
|
"vmov.32 d2[0], %[multiplicative_offset]\n"
|
|
"vdup.32 q4, %[additive_offset]\n"
|
|
"vpaddl.u16 q2, q2\n"
|
|
"vpaddl.u16 q3, q3\n"
|
|
"vpadd.u32 d3, d4, d5\n"
|
|
"vpadd.u32 d10, d6, d7\n"
|
|
"vpadd.u32 d12, d3, d10\n"
|
|
"vmul.i32 q6, q6, d2[0]\n"
|
|
"vadd.i32 q6, q6, q4\n"
|
|
"vst1.32 {d12}, [%[destination]:64]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
|
|
[destination] "+r"(destination), [source] "+r"(source)
|
|
:
|
|
: "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
|
|
"d12", "d13", "cc", "memory");
|
|
}
|
|
|
|
void zip_2x8_3_aligned(const std::uint8_t* source, std::int32_t count,
|
|
std::int32_t stride, std::uint8_t* destination,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t additive_offset) {
|
|
asm volatile(
|
|
"add r0, %[source], %[stride]\n"
|
|
"sub %[count], %[count], #3\n"
|
|
"vmov.i16 q2, #0\n"
|
|
"vmov.i16 q3, #0\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
// Load Aggregate Store.
|
|
"vld1.8 {d0}, [%[source]:64]!\n"
|
|
"vld1.8 {d1}, [r0:64]!\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vaddw.u8 q3, q3, d1\n"
|
|
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
|
|
// Leftover Load Aggregate Store.
|
|
"vmov.i8 d0, #0\n"
|
|
"vmov.i8 d1, #0\n"
|
|
"vld1.16 {d0[0]}, [%[source]]!\n"
|
|
"vld1.16 {d1[0]}, [r0]!\n"
|
|
"vld1.8 {d0[2]}, [%[source]]\n"
|
|
"vld1.8 {d1[2]}, [r0]\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vaddw.u8 q3, q3, d1\n"
|
|
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
|
|
|
|
// Aggregator Reduction.
|
|
"vmov.32 d2[0], %[multiplicative_offset]\n"
|
|
"vdup.32 q4, %[additive_offset]\n"
|
|
"vpaddl.u16 q2, q2\n"
|
|
"vpaddl.u16 q3, q3\n"
|
|
"vpadd.u32 d3, d4, d5\n"
|
|
"vpadd.u32 d10, d6, d7\n"
|
|
"vpadd.u32 d12, d3, d10\n"
|
|
"vmul.i32 q6, q6, d2[0]\n"
|
|
"vadd.i32 q6, q6, q4\n"
|
|
"vst1.32 {d12}, [%[destination]:64]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
|
|
[destination] "+r"(destination), [source] "+r"(source)
|
|
:
|
|
: "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
|
|
"d12", "d13", "cc", "memory");
|
|
}
|
|
|
|
void zip_2x8_4_aligned(const std::uint8_t* source, std::int32_t count,
|
|
std::int32_t stride, std::uint8_t* destination,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t additive_offset) {
|
|
asm volatile(
|
|
"add r0, %[source], %[stride]\n"
|
|
"sub %[count], %[count], #4\n"
|
|
"vmov.i16 q2, #0\n"
|
|
"vmov.i16 q3, #0\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
// Load Aggregate Store.
|
|
"vld1.8 {d0}, [%[source]:64]!\n"
|
|
"vld1.8 {d1}, [r0:64]!\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vaddw.u8 q3, q3, d1\n"
|
|
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
|
|
// Leftover Load Aggregate Store.
|
|
"vmov.i8 d0, #0\n"
|
|
"vmov.i8 d1, #0\n"
|
|
"vld1.32 {d0[0]}, [%[source]]\n"
|
|
"vld1.32 {d1[0]}, [r0]\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vaddw.u8 q3, q3, d1\n"
|
|
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
|
|
|
|
// Aggregator Reduction.
|
|
"vmov.32 d2[0], %[multiplicative_offset]\n"
|
|
"vdup.32 q4, %[additive_offset]\n"
|
|
"vpaddl.u16 q2, q2\n"
|
|
"vpaddl.u16 q3, q3\n"
|
|
"vpadd.u32 d3, d4, d5\n"
|
|
"vpadd.u32 d10, d6, d7\n"
|
|
"vpadd.u32 d12, d3, d10\n"
|
|
"vmul.i32 q6, q6, d2[0]\n"
|
|
"vadd.i32 q6, q6, q4\n"
|
|
"vst1.32 {d12}, [%[destination]:64]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
|
|
[destination] "+r"(destination), [source] "+r"(source)
|
|
:
|
|
: "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
|
|
"d12", "d13", "cc", "memory");
|
|
}
|
|
|
|
void zip_2x8_5_aligned(const std::uint8_t* source, std::int32_t count,
|
|
std::int32_t stride, std::uint8_t* destination,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t additive_offset) {
|
|
asm volatile(
|
|
"add r0, %[source], %[stride]\n"
|
|
"sub %[count], %[count], #5\n"
|
|
"vmov.i16 q2, #0\n"
|
|
"vmov.i16 q3, #0\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
// Load Aggregate Store.
|
|
"vld1.8 {d0}, [%[source]:64]!\n"
|
|
"vld1.8 {d1}, [r0:64]!\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vaddw.u8 q3, q3, d1\n"
|
|
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
|
|
// Leftover Load Aggregate Store.
|
|
"vmov.i8 d0, #0\n"
|
|
"vmov.i8 d1, #0\n"
|
|
"vld1.32 {d0[0]}, [%[source]]!\n"
|
|
"vld1.32 {d1[0]}, [r0]!\n"
|
|
"vld1.8 {d0[4]}, [%[source]]\n"
|
|
"vld1.8 {d1[4]}, [r0]\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vaddw.u8 q3, q3, d1\n"
|
|
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
|
|
|
|
// Aggregator Reduction.
|
|
"vmov.32 d2[0], %[multiplicative_offset]\n"
|
|
"vdup.32 q4, %[additive_offset]\n"
|
|
"vpaddl.u16 q2, q2\n"
|
|
"vpaddl.u16 q3, q3\n"
|
|
"vpadd.u32 d3, d4, d5\n"
|
|
"vpadd.u32 d10, d6, d7\n"
|
|
"vpadd.u32 d12, d3, d10\n"
|
|
"vmul.i32 q6, q6, d2[0]\n"
|
|
"vadd.i32 q6, q6, q4\n"
|
|
"vst1.32 {d12}, [%[destination]:64]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
|
|
[destination] "+r"(destination), [source] "+r"(source)
|
|
:
|
|
: "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
|
|
"d12", "d13", "cc", "memory");
|
|
}
|
|
|
|
void zip_2x8_6_aligned(const std::uint8_t* source, std::int32_t count,
|
|
std::int32_t stride, std::uint8_t* destination,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t additive_offset) {
|
|
asm volatile(
|
|
"add r0, %[source], %[stride]\n"
|
|
"sub %[count], %[count], #6\n"
|
|
"vmov.i16 q2, #0\n"
|
|
"vmov.i16 q3, #0\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
// Load Aggregate Store.
|
|
"vld1.8 {d0}, [%[source]:64]!\n"
|
|
"vld1.8 {d1}, [r0:64]!\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vaddw.u8 q3, q3, d1\n"
|
|
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
|
|
// Leftover Load Aggregate Store.
|
|
"vmov.i8 d0, #0\n"
|
|
"vmov.i8 d1, #0\n"
|
|
"vld1.32 {d0[0]}, [%[source]]!\n"
|
|
"vld1.32 {d1[0]}, [r0]!\n"
|
|
"vld1.16 {d0[2]}, [%[source]]\n"
|
|
"vld1.16 {d1[2]}, [r0]\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vaddw.u8 q3, q3, d1\n"
|
|
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
|
|
|
|
// Aggregator Reduction.
|
|
"vmov.32 d2[0], %[multiplicative_offset]\n"
|
|
"vdup.32 q4, %[additive_offset]\n"
|
|
"vpaddl.u16 q2, q2\n"
|
|
"vpaddl.u16 q3, q3\n"
|
|
"vpadd.u32 d3, d4, d5\n"
|
|
"vpadd.u32 d10, d6, d7\n"
|
|
"vpadd.u32 d12, d3, d10\n"
|
|
"vmul.i32 q6, q6, d2[0]\n"
|
|
"vadd.i32 q6, q6, q4\n"
|
|
"vst1.32 {d12}, [%[destination]:64]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
|
|
[destination] "+r"(destination), [source] "+r"(source)
|
|
:
|
|
: "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
|
|
"d12", "d13", "cc", "memory");
|
|
}
|
|
|
|
void zip_2x8_7_aligned(const std::uint8_t* source, std::int32_t count,
|
|
std::int32_t stride, std::uint8_t* destination,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t additive_offset) {
|
|
asm volatile(
|
|
"add r0, %[source], %[stride]\n"
|
|
"sub %[count], %[count], #7\n"
|
|
"vmov.i16 q2, #0\n"
|
|
"vmov.i16 q3, #0\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
// Load Aggregate Store.
|
|
"vld1.8 {d0}, [%[source]:64]!\n"
|
|
"vld1.8 {d1}, [r0:64]!\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vaddw.u8 q3, q3, d1\n"
|
|
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
|
|
// Leftover Load Aggregate Store.
|
|
"vmov.i8 d0, #0\n"
|
|
"vmov.i8 d1, #0\n"
|
|
"vld1.32 {d0[0]}, [%[source]]!\n"
|
|
"vld1.32 {d1[0]}, [r0]!\n"
|
|
"vld1.16 {d0[2]}, [%[source]]!\n"
|
|
"vld1.16 {d1[2]}, [r0]!\n"
|
|
"vld1.8 {d0[6]}, [%[source]]\n"
|
|
"vld1.8 {d1[6]}, [r0]\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vaddw.u8 q3, q3, d1\n"
|
|
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
|
|
|
|
// Aggregator Reduction.
|
|
"vmov.32 d2[0], %[multiplicative_offset]\n"
|
|
"vdup.32 q4, %[additive_offset]\n"
|
|
"vpaddl.u16 q2, q2\n"
|
|
"vpaddl.u16 q3, q3\n"
|
|
"vpadd.u32 d3, d4, d5\n"
|
|
"vpadd.u32 d10, d6, d7\n"
|
|
"vpadd.u32 d12, d3, d10\n"
|
|
"vmul.i32 q6, q6, d2[0]\n"
|
|
"vadd.i32 q6, q6, q4\n"
|
|
"vst1.32 {d12}, [%[destination]:64]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
|
|
[destination] "+r"(destination), [source] "+r"(source)
|
|
:
|
|
: "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
|
|
"d12", "d13", "cc", "memory");
|
|
}
|
|
|
|
void zip_3x8_aligned(const std::uint8_t* source, std::int32_t count,
|
|
std::int32_t stride, std::uint8_t* destination,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t additive_offset) {
|
|
asm volatile(
|
|
"add r0, %[source], %[stride]\n"
|
|
"add r1, r0, %[stride]\n"
|
|
"vmov.i16 q2, #0\n"
|
|
"vmov.i16 q3, #0\n"
|
|
"vmov.i16 q4, #0\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
// Load Aggregate Store.
|
|
"vld1.8 {d0}, [%[source]:64]!\n"
|
|
"vld1.8 {d1}, [r0:64]!\n"
|
|
"vld1.8 {d2}, [r1:64]!\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vaddw.u8 q3, q3, d1\n"
|
|
"vaddw.u8 q4, q4, d2\n"
|
|
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
|
|
// Aggregator Reduction.
|
|
"vmov.32 d3[0], %[multiplicative_offset]\n"
|
|
"vdup.32 q5, %[additive_offset]\n"
|
|
"vpaddl.u16 q2, q2\n"
|
|
"vpaddl.u16 q3, q3\n"
|
|
"vpaddl.u16 q4, q4\n"
|
|
"vpadd.u32 d12, d4, d5\n"
|
|
"vpadd.u32 d13, d6, d7\n"
|
|
"vpadd.u32 d14, d8, d9\n"
|
|
"vpadd.u32 d16, d12, d13\n"
|
|
"vpadd.u32 d17, d14, d14\n"
|
|
"vmul.i32 q8, q8, d3[0]\n"
|
|
"vadd.i32 q8, q8, q5\n"
|
|
"vst1.32 {d16}, [%[destination]:64]!\n"
|
|
"vst1.32 {d17[0]}, [%[destination]]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
|
|
[destination] "+r"(destination), [source] "+r"(source)
|
|
:
|
|
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
|
|
"d10", "d11", "d12", "d13", "d14", "d16", "d17", "cc", "memory");
|
|
}
|
|
|
|
void zip_3x8_1_aligned(const std::uint8_t* source, std::int32_t count,
|
|
std::int32_t stride, std::uint8_t* destination,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t additive_offset) {
|
|
asm volatile(
|
|
"add r0, %[source], %[stride]\n"
|
|
"add r1, r0, %[stride]\n"
|
|
"sub %[count], %[count], #1\n"
|
|
"vmov.i16 q2, #0\n"
|
|
"vmov.i16 q3, #0\n"
|
|
"vmov.i16 q4, #0\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
// Load Aggregate Store.
|
|
"vld1.8 {d0}, [%[source]:64]!\n"
|
|
"vld1.8 {d1}, [r0:64]!\n"
|
|
"vld1.8 {d2}, [r1:64]!\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vaddw.u8 q3, q3, d1\n"
|
|
"vaddw.u8 q4, q4, d2\n"
|
|
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
|
|
// Leftover Load Aggregate Store.
|
|
"vmov.i8 d0, #0\n"
|
|
"vmov.i8 d1, #0\n"
|
|
"vmov.i8 d2, #0\n"
|
|
"vld1.8 {d0[0]}, [%[source]]\n"
|
|
"vld1.8 {d1[0]}, [r0]\n"
|
|
"vld1.8 {d2[0]}, [r1]\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vaddw.u8 q3, q3, d1\n"
|
|
"vaddw.u8 q4, q4, d2\n"
|
|
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
|
|
|
|
// Aggregator Reduction.
|
|
"vmov.32 d3[0], %[multiplicative_offset]\n"
|
|
"vdup.32 q5, %[additive_offset]\n"
|
|
"vpaddl.u16 q2, q2\n"
|
|
"vpaddl.u16 q3, q3\n"
|
|
"vpaddl.u16 q4, q4\n"
|
|
"vpadd.u32 d12, d4, d5\n"
|
|
"vpadd.u32 d13, d6, d7\n"
|
|
"vpadd.u32 d14, d8, d9\n"
|
|
"vpadd.u32 d16, d12, d13\n"
|
|
"vpadd.u32 d17, d14, d14\n"
|
|
"vmul.i32 q8, q8, d3[0]\n"
|
|
"vadd.i32 q8, q8, q5\n"
|
|
"vst1.32 {d16}, [%[destination]:64]!\n"
|
|
"vst1.32 {d17[0]}, [%[destination]]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
|
|
[destination] "+r"(destination), [source] "+r"(source)
|
|
:
|
|
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
|
|
"d10", "d11", "d12", "d13", "d14", "d16", "d17", "cc", "memory");
|
|
}
|
|
|
|
void zip_3x8_2_aligned(const std::uint8_t* source, std::int32_t count,
|
|
std::int32_t stride, std::uint8_t* destination,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t additive_offset) {
|
|
asm volatile(
|
|
"add r0, %[source], %[stride]\n"
|
|
"add r1, r0, %[stride]\n"
|
|
"sub %[count], %[count], #2\n"
|
|
"vmov.i16 q2, #0\n"
|
|
"vmov.i16 q3, #0\n"
|
|
"vmov.i16 q4, #0\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
// Load Aggregate Store.
|
|
"vld1.8 {d0}, [%[source]:64]!\n"
|
|
"vld1.8 {d1}, [r0:64]!\n"
|
|
"vld1.8 {d2}, [r1:64]!\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vaddw.u8 q3, q3, d1\n"
|
|
"vaddw.u8 q4, q4, d2\n"
|
|
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
|
|
// Leftover Load Aggregate Store.
|
|
"vmov.i8 d0, #0\n"
|
|
"vmov.i8 d1, #0\n"
|
|
"vmov.i8 d2, #0\n"
|
|
"vld1.16 {d0[0]}, [%[source]]\n"
|
|
"vld1.16 {d1[0]}, [r0]\n"
|
|
"vld1.16 {d2[0]}, [r1]\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vaddw.u8 q3, q3, d1\n"
|
|
"vaddw.u8 q4, q4, d2\n"
|
|
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
|
|
|
|
// Aggregator Reduction.
|
|
"vmov.32 d3[0], %[multiplicative_offset]\n"
|
|
"vdup.32 q5, %[additive_offset]\n"
|
|
"vpaddl.u16 q2, q2\n"
|
|
"vpaddl.u16 q3, q3\n"
|
|
"vpaddl.u16 q4, q4\n"
|
|
"vpadd.u32 d12, d4, d5\n"
|
|
"vpadd.u32 d13, d6, d7\n"
|
|
"vpadd.u32 d14, d8, d9\n"
|
|
"vpadd.u32 d16, d12, d13\n"
|
|
"vpadd.u32 d17, d14, d14\n"
|
|
"vmul.i32 q8, q8, d3[0]\n"
|
|
"vadd.i32 q8, q8, q5\n"
|
|
"vst1.32 {d16}, [%[destination]:64]!\n"
|
|
"vst1.32 {d17[0]}, [%[destination]]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
|
|
[destination] "+r"(destination), [source] "+r"(source)
|
|
:
|
|
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
|
|
"d10", "d11", "d12", "d13", "d14", "d16", "d17", "cc", "memory");
|
|
}
|
|
|
|
void zip_3x8_3_aligned(const std::uint8_t* source, std::int32_t count,
|
|
std::int32_t stride, std::uint8_t* destination,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t additive_offset) {
|
|
asm volatile(
|
|
"add r0, %[source], %[stride]\n"
|
|
"add r1, r0, %[stride]\n"
|
|
"sub %[count], %[count], #3\n"
|
|
"vmov.i16 q2, #0\n"
|
|
"vmov.i16 q3, #0\n"
|
|
"vmov.i16 q4, #0\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
// Load Aggregate Store.
|
|
"vld1.8 {d0}, [%[source]:64]!\n"
|
|
"vld1.8 {d1}, [r0:64]!\n"
|
|
"vld1.8 {d2}, [r1:64]!\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vaddw.u8 q3, q3, d1\n"
|
|
"vaddw.u8 q4, q4, d2\n"
|
|
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
|
|
// Leftover Load Aggregate Store.
|
|
"vmov.i8 d0, #0\n"
|
|
"vmov.i8 d1, #0\n"
|
|
"vmov.i8 d2, #0\n"
|
|
"vld1.16 {d0[0]}, [%[source]]!\n"
|
|
"vld1.16 {d1[0]}, [r0]!\n"
|
|
"vld1.16 {d2[0]}, [r1]!\n"
|
|
"vld1.8 {d0[2]}, [%[source]]\n"
|
|
"vld1.8 {d1[2]}, [r0]\n"
|
|
"vld1.8 {d2[2]}, [r1]\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vaddw.u8 q3, q3, d1\n"
|
|
"vaddw.u8 q4, q4, d2\n"
|
|
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
|
|
|
|
// Aggregator Reduction.
|
|
"vmov.32 d3[0], %[multiplicative_offset]\n"
|
|
"vdup.32 q5, %[additive_offset]\n"
|
|
"vpaddl.u16 q2, q2\n"
|
|
"vpaddl.u16 q3, q3\n"
|
|
"vpaddl.u16 q4, q4\n"
|
|
"vpadd.u32 d12, d4, d5\n"
|
|
"vpadd.u32 d13, d6, d7\n"
|
|
"vpadd.u32 d14, d8, d9\n"
|
|
"vpadd.u32 d16, d12, d13\n"
|
|
"vpadd.u32 d17, d14, d14\n"
|
|
"vmul.i32 q8, q8, d3[0]\n"
|
|
"vadd.i32 q8, q8, q5\n"
|
|
"vst1.32 {d16}, [%[destination]:64]!\n"
|
|
"vst1.32 {d17[0]}, [%[destination]]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
|
|
[destination] "+r"(destination), [source] "+r"(source)
|
|
:
|
|
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
|
|
"d10", "d11", "d12", "d13", "d14", "d16", "d17", "cc", "memory");
|
|
}
|
|
|
|
void zip_3x8_4_aligned(const std::uint8_t* source, std::int32_t count,
|
|
std::int32_t stride, std::uint8_t* destination,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t additive_offset) {
|
|
asm volatile(
|
|
"add r0, %[source], %[stride]\n"
|
|
"add r1, r0, %[stride]\n"
|
|
"sub %[count], %[count], #4\n"
|
|
"vmov.i16 q2, #0\n"
|
|
"vmov.i16 q3, #0\n"
|
|
"vmov.i16 q4, #0\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
// Load Aggregate Store.
|
|
"vld1.8 {d0}, [%[source]:64]!\n"
|
|
"vld1.8 {d1}, [r0:64]!\n"
|
|
"vld1.8 {d2}, [r1:64]!\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vaddw.u8 q3, q3, d1\n"
|
|
"vaddw.u8 q4, q4, d2\n"
|
|
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
|
|
// Leftover Load Aggregate Store.
|
|
"vmov.i8 d0, #0\n"
|
|
"vmov.i8 d1, #0\n"
|
|
"vmov.i8 d2, #0\n"
|
|
"vld1.32 {d0[0]}, [%[source]]\n"
|
|
"vld1.32 {d1[0]}, [r0]\n"
|
|
"vld1.32 {d2[0]}, [r1]\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vaddw.u8 q3, q3, d1\n"
|
|
"vaddw.u8 q4, q4, d2\n"
|
|
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
|
|
|
|
// Aggregator Reduction.
|
|
"vmov.32 d3[0], %[multiplicative_offset]\n"
|
|
"vdup.32 q5, %[additive_offset]\n"
|
|
"vpaddl.u16 q2, q2\n"
|
|
"vpaddl.u16 q3, q3\n"
|
|
"vpaddl.u16 q4, q4\n"
|
|
"vpadd.u32 d12, d4, d5\n"
|
|
"vpadd.u32 d13, d6, d7\n"
|
|
"vpadd.u32 d14, d8, d9\n"
|
|
"vpadd.u32 d16, d12, d13\n"
|
|
"vpadd.u32 d17, d14, d14\n"
|
|
"vmul.i32 q8, q8, d3[0]\n"
|
|
"vadd.i32 q8, q8, q5\n"
|
|
"vst1.32 {d16}, [%[destination]:64]!\n"
|
|
"vst1.32 {d17[0]}, [%[destination]]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
|
|
[destination] "+r"(destination), [source] "+r"(source)
|
|
:
|
|
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
|
|
"d10", "d11", "d12", "d13", "d14", "d16", "d17", "cc", "memory");
|
|
}
|
|
|
|
void zip_3x8_5_aligned(const std::uint8_t* source, std::int32_t count,
|
|
std::int32_t stride, std::uint8_t* destination,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t additive_offset) {
|
|
asm volatile(
|
|
"add r0, %[source], %[stride]\n"
|
|
"add r1, r0, %[stride]\n"
|
|
"sub %[count], %[count], #5\n"
|
|
"vmov.i16 q2, #0\n"
|
|
"vmov.i16 q3, #0\n"
|
|
"vmov.i16 q4, #0\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
// Load Aggregate Store.
|
|
"vld1.8 {d0}, [%[source]:64]!\n"
|
|
"vld1.8 {d1}, [r0:64]!\n"
|
|
"vld1.8 {d2}, [r1:64]!\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vaddw.u8 q3, q3, d1\n"
|
|
"vaddw.u8 q4, q4, d2\n"
|
|
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
|
|
// Leftover Load Aggregate Store.
|
|
"vmov.i8 d0, #0\n"
|
|
"vmov.i8 d1, #0\n"
|
|
"vmov.i8 d2, #0\n"
|
|
"vld1.32 {d0[0]}, [%[source]]!\n"
|
|
"vld1.32 {d1[0]}, [r0]!\n"
|
|
"vld1.32 {d2[0]}, [r1]!\n"
|
|
"vld1.8 {d0[4]}, [%[source]]\n"
|
|
"vld1.8 {d1[4]}, [r0]\n"
|
|
"vld1.8 {d2[4]}, [r1]\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vaddw.u8 q3, q3, d1\n"
|
|
"vaddw.u8 q4, q4, d2\n"
|
|
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
|
|
|
|
// Aggregator Reduction.
|
|
"vmov.32 d3[0], %[multiplicative_offset]\n"
|
|
"vdup.32 q5, %[additive_offset]\n"
|
|
"vpaddl.u16 q2, q2\n"
|
|
"vpaddl.u16 q3, q3\n"
|
|
"vpaddl.u16 q4, q4\n"
|
|
"vpadd.u32 d12, d4, d5\n"
|
|
"vpadd.u32 d13, d6, d7\n"
|
|
"vpadd.u32 d14, d8, d9\n"
|
|
"vpadd.u32 d16, d12, d13\n"
|
|
"vpadd.u32 d17, d14, d14\n"
|
|
"vmul.i32 q8, q8, d3[0]\n"
|
|
"vadd.i32 q8, q8, q5\n"
|
|
"vst1.32 {d16}, [%[destination]:64]!\n"
|
|
"vst1.32 {d17[0]}, [%[destination]]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
|
|
[destination] "+r"(destination), [source] "+r"(source)
|
|
:
|
|
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
|
|
"d10", "d11", "d12", "d13", "d14", "d16", "d17", "cc", "memory");
|
|
}
|
|
|
|
void zip_3x8_6_aligned(const std::uint8_t* source, std::int32_t count,
|
|
std::int32_t stride, std::uint8_t* destination,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t additive_offset) {
|
|
asm volatile(
|
|
"add r0, %[source], %[stride]\n"
|
|
"add r1, r0, %[stride]\n"
|
|
"sub %[count], %[count], #6\n"
|
|
"vmov.i16 q2, #0\n"
|
|
"vmov.i16 q3, #0\n"
|
|
"vmov.i16 q4, #0\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
// Load Aggregate Store.
|
|
"vld1.8 {d0}, [%[source]:64]!\n"
|
|
"vld1.8 {d1}, [r0:64]!\n"
|
|
"vld1.8 {d2}, [r1:64]!\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vaddw.u8 q3, q3, d1\n"
|
|
"vaddw.u8 q4, q4, d2\n"
|
|
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
|
|
// Leftover Load Aggregate Store.
|
|
"vmov.i8 d0, #0\n"
|
|
"vmov.i8 d1, #0\n"
|
|
"vmov.i8 d2, #0\n"
|
|
"vld1.32 {d0[0]}, [%[source]]!\n"
|
|
"vld1.32 {d1[0]}, [r0]!\n"
|
|
"vld1.32 {d2[0]}, [r1]!\n"
|
|
"vld1.16 {d0[2]}, [%[source]]\n"
|
|
"vld1.16 {d1[2]}, [r0]\n"
|
|
"vld1.16 {d2[2]}, [r1]\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vaddw.u8 q3, q3, d1\n"
|
|
"vaddw.u8 q4, q4, d2\n"
|
|
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
|
|
|
|
// Aggregator Reduction.
|
|
"vmov.32 d3[0], %[multiplicative_offset]\n"
|
|
"vdup.32 q5, %[additive_offset]\n"
|
|
"vpaddl.u16 q2, q2\n"
|
|
"vpaddl.u16 q3, q3\n"
|
|
"vpaddl.u16 q4, q4\n"
|
|
"vpadd.u32 d12, d4, d5\n"
|
|
"vpadd.u32 d13, d6, d7\n"
|
|
"vpadd.u32 d14, d8, d9\n"
|
|
"vpadd.u32 d16, d12, d13\n"
|
|
"vpadd.u32 d17, d14, d14\n"
|
|
"vmul.i32 q8, q8, d3[0]\n"
|
|
"vadd.i32 q8, q8, q5\n"
|
|
"vst1.32 {d16}, [%[destination]:64]!\n"
|
|
"vst1.32 {d17[0]}, [%[destination]]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
|
|
[destination] "+r"(destination), [source] "+r"(source)
|
|
:
|
|
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
|
|
"d10", "d11", "d12", "d13", "d14", "d16", "d17", "cc", "memory");
|
|
}
|
|
|
|
void zip_3x8_7_aligned(const std::uint8_t* source, std::int32_t count,
|
|
std::int32_t stride, std::uint8_t* destination,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t additive_offset) {
|
|
asm volatile(
|
|
"add r0, %[source], %[stride]\n"
|
|
"add r1, r0, %[stride]\n"
|
|
"sub %[count], %[count], #7\n"
|
|
"vmov.i16 q2, #0\n"
|
|
"vmov.i16 q3, #0\n"
|
|
"vmov.i16 q4, #0\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
// Load Aggregate Store.
|
|
"vld1.8 {d0}, [%[source]:64]!\n"
|
|
"vld1.8 {d1}, [r0:64]!\n"
|
|
"vld1.8 {d2}, [r1:64]!\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vaddw.u8 q3, q3, d1\n"
|
|
"vaddw.u8 q4, q4, d2\n"
|
|
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
|
|
// Leftover Load Aggregate Store.
|
|
"vmov.i8 d0, #0\n"
|
|
"vmov.i8 d1, #0\n"
|
|
"vmov.i8 d2, #0\n"
|
|
"vld1.32 {d0[0]}, [%[source]]!\n"
|
|
"vld1.32 {d1[0]}, [r0]!\n"
|
|
"vld1.32 {d2[0]}, [r1]!\n"
|
|
"vld1.16 {d0[2]}, [%[source]]!\n"
|
|
"vld1.16 {d1[2]}, [r0]!\n"
|
|
"vld1.16 {d2[2]}, [r1]!\n"
|
|
"vld1.8 {d0[6]}, [%[source]]\n"
|
|
"vld1.8 {d1[6]}, [r0]\n"
|
|
"vld1.8 {d2[6]}, [r1]\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vaddw.u8 q3, q3, d1\n"
|
|
"vaddw.u8 q4, q4, d2\n"
|
|
"vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"
|
|
|
|
// Aggregator Reduction.
|
|
"vmov.32 d3[0], %[multiplicative_offset]\n"
|
|
"vdup.32 q5, %[additive_offset]\n"
|
|
"vpaddl.u16 q2, q2\n"
|
|
"vpaddl.u16 q3, q3\n"
|
|
"vpaddl.u16 q4, q4\n"
|
|
"vpadd.u32 d12, d4, d5\n"
|
|
"vpadd.u32 d13, d6, d7\n"
|
|
"vpadd.u32 d14, d8, d9\n"
|
|
"vpadd.u32 d16, d12, d13\n"
|
|
"vpadd.u32 d17, d14, d14\n"
|
|
"vmul.i32 q8, q8, d3[0]\n"
|
|
"vadd.i32 q8, q8, q5\n"
|
|
"vst1.32 {d16}, [%[destination]:64]!\n"
|
|
"vst1.32 {d17[0]}, [%[destination]]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
|
|
[destination] "+r"(destination), [source] "+r"(source)
|
|
:
|
|
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
|
|
"d10", "d11", "d12", "d13", "d14", "d16", "d17", "cc", "memory");
|
|
}
|
|
|
|
void zip_1x8(const std::uint8_t* source, std::int32_t count,
|
|
std::int32_t stride, std::uint8_t* destination,
|
|
std::int32_t multiplicative_offset, std::int32_t additive_offset) {
|
|
asm volatile(
|
|
"vmov.i16 q2, #0\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
// Load Aggregate Store.
|
|
"vld1.8 {d0}, [%[source]]!\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vst1.8 {d0}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
|
|
// Aggregator Reduction.
|
|
"vmov.32 d1[0], %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[additive_offset]\n"
|
|
"vpaddl.u16 q2, q2\n"
|
|
"vpadd.u32 d6, d4, d5\n"
|
|
"vpadd.u32 d8, d6, d6\n"
|
|
"vmul.i32 q4, q4, d1[0]\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vst1.32 {d8[0]}, [%[destination]]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
|
|
[destination] "+r"(destination), [source] "+r"(source)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "cc", "memory");
|
|
}
|
|
|
|
void zip_1x8_1(const std::uint8_t* source, std::int32_t count,
|
|
std::int32_t stride, std::uint8_t* destination,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t additive_offset) {
|
|
asm volatile(
|
|
"sub %[count], %[count], #1\n"
|
|
"vmov.i16 q2, #0\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
// Load Aggregate Store.
|
|
"vld1.8 {d0}, [%[source]]!\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vst1.8 {d0}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
|
|
// Leftover Load Aggregate Store.
|
|
"vmov.i8 d0, #0\n"
|
|
"vld1.8 {d0[0]}, [%[source]]\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vst1.8 {d0}, [%[destination]:64]!\n"
|
|
|
|
// Aggregator Reduction.
|
|
"vmov.32 d1[0], %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[additive_offset]\n"
|
|
"vpaddl.u16 q2, q2\n"
|
|
"vpadd.u32 d6, d4, d5\n"
|
|
"vpadd.u32 d8, d6, d6\n"
|
|
"vmul.i32 q4, q4, d1[0]\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vst1.32 {d8[0]}, [%[destination]]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
|
|
[destination] "+r"(destination), [source] "+r"(source)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "cc", "memory");
|
|
}
|
|
|
|
void zip_1x8_2(const std::uint8_t* source, std::int32_t count,
|
|
std::int32_t stride, std::uint8_t* destination,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t additive_offset) {
|
|
asm volatile(
|
|
"sub %[count], %[count], #2\n"
|
|
"vmov.i16 q2, #0\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
// Load Aggregate Store.
|
|
"vld1.8 {d0}, [%[source]]!\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vst1.8 {d0}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
|
|
// Leftover Load Aggregate Store.
|
|
"vmov.i8 d0, #0\n"
|
|
"vld1.16 {d0[0]}, [%[source]]\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vst1.8 {d0}, [%[destination]:64]!\n"
|
|
|
|
// Aggregator Reduction.
|
|
"vmov.32 d1[0], %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[additive_offset]\n"
|
|
"vpaddl.u16 q2, q2\n"
|
|
"vpadd.u32 d6, d4, d5\n"
|
|
"vpadd.u32 d8, d6, d6\n"
|
|
"vmul.i32 q4, q4, d1[0]\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vst1.32 {d8[0]}, [%[destination]]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
|
|
[destination] "+r"(destination), [source] "+r"(source)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "cc", "memory");
|
|
}
|
|
|
|
void zip_1x8_3(const std::uint8_t* source, std::int32_t count,
|
|
std::int32_t stride, std::uint8_t* destination,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t additive_offset) {
|
|
asm volatile(
|
|
"sub %[count], %[count], #3\n"
|
|
"vmov.i16 q2, #0\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
// Load Aggregate Store.
|
|
"vld1.8 {d0}, [%[source]]!\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vst1.8 {d0}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
|
|
// Leftover Load Aggregate Store.
|
|
"vmov.i8 d0, #0\n"
|
|
"vld1.16 {d0[0]}, [%[source]]!\n"
|
|
"vld1.8 {d0[2]}, [%[source]]\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vst1.8 {d0}, [%[destination]:64]!\n"
|
|
|
|
// Aggregator Reduction.
|
|
"vmov.32 d1[0], %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[additive_offset]\n"
|
|
"vpaddl.u16 q2, q2\n"
|
|
"vpadd.u32 d6, d4, d5\n"
|
|
"vpadd.u32 d8, d6, d6\n"
|
|
"vmul.i32 q4, q4, d1[0]\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vst1.32 {d8[0]}, [%[destination]]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
|
|
[destination] "+r"(destination), [source] "+r"(source)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "cc", "memory");
|
|
}
|
|
|
|
void zip_1x8_4(const std::uint8_t* source, std::int32_t count,
|
|
std::int32_t stride, std::uint8_t* destination,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t additive_offset) {
|
|
asm volatile(
|
|
"sub %[count], %[count], #4\n"
|
|
"vmov.i16 q2, #0\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
// Load Aggregate Store.
|
|
"vld1.8 {d0}, [%[source]]!\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vst1.8 {d0}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
|
|
// Leftover Load Aggregate Store.
|
|
"vmov.i8 d0, #0\n"
|
|
"vld1.32 {d0[0]}, [%[source]]\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vst1.8 {d0}, [%[destination]:64]!\n"
|
|
|
|
// Aggregator Reduction.
|
|
"vmov.32 d1[0], %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[additive_offset]\n"
|
|
"vpaddl.u16 q2, q2\n"
|
|
"vpadd.u32 d6, d4, d5\n"
|
|
"vpadd.u32 d8, d6, d6\n"
|
|
"vmul.i32 q4, q4, d1[0]\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vst1.32 {d8[0]}, [%[destination]]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
|
|
[destination] "+r"(destination), [source] "+r"(source)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "cc", "memory");
|
|
}
|
|
|
|
void zip_1x8_5(const std::uint8_t* source, std::int32_t count,
|
|
std::int32_t stride, std::uint8_t* destination,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t additive_offset) {
|
|
asm volatile(
|
|
"sub %[count], %[count], #5\n"
|
|
"vmov.i16 q2, #0\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
// Load Aggregate Store.
|
|
"vld1.8 {d0}, [%[source]]!\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vst1.8 {d0}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
|
|
// Leftover Load Aggregate Store.
|
|
"vmov.i8 d0, #0\n"
|
|
"vld1.32 {d0[0]}, [%[source]]!\n"
|
|
"vld1.8 {d0[4]}, [%[source]]\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vst1.8 {d0}, [%[destination]:64]!\n"
|
|
|
|
// Aggregator Reduction.
|
|
"vmov.32 d1[0], %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[additive_offset]\n"
|
|
"vpaddl.u16 q2, q2\n"
|
|
"vpadd.u32 d6, d4, d5\n"
|
|
"vpadd.u32 d8, d6, d6\n"
|
|
"vmul.i32 q4, q4, d1[0]\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vst1.32 {d8[0]}, [%[destination]]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
|
|
[destination] "+r"(destination), [source] "+r"(source)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "cc", "memory");
|
|
}
|
|
|
|
void zip_1x8_6(const std::uint8_t* source, std::int32_t count,
|
|
std::int32_t stride, std::uint8_t* destination,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t additive_offset) {
|
|
asm volatile(
|
|
"sub %[count], %[count], #6\n"
|
|
"vmov.i16 q2, #0\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
// Load Aggregate Store.
|
|
"vld1.8 {d0}, [%[source]]!\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vst1.8 {d0}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
|
|
// Leftover Load Aggregate Store.
|
|
"vmov.i8 d0, #0\n"
|
|
"vld1.32 {d0[0]}, [%[source]]!\n"
|
|
"vld1.16 {d0[2]}, [%[source]]\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vst1.8 {d0}, [%[destination]:64]!\n"
|
|
|
|
// Aggregator Reduction.
|
|
"vmov.32 d1[0], %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[additive_offset]\n"
|
|
"vpaddl.u16 q2, q2\n"
|
|
"vpadd.u32 d6, d4, d5\n"
|
|
"vpadd.u32 d8, d6, d6\n"
|
|
"vmul.i32 q4, q4, d1[0]\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vst1.32 {d8[0]}, [%[destination]]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
|
|
[destination] "+r"(destination), [source] "+r"(source)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "cc", "memory");
|
|
}
|
|
|
|
void zip_1x8_7(const std::uint8_t* source, std::int32_t count,
|
|
std::int32_t stride, std::uint8_t* destination,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t additive_offset) {
|
|
asm volatile(
|
|
"sub %[count], %[count], #7\n"
|
|
"vmov.i16 q2, #0\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
// Load Aggregate Store.
|
|
"vld1.8 {d0}, [%[source]]!\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vst1.8 {d0}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
|
|
// Leftover Load Aggregate Store.
|
|
"vmov.i8 d0, #0\n"
|
|
"vld1.32 {d0[0]}, [%[source]]!\n"
|
|
"vld1.16 {d0[2]}, [%[source]]!\n"
|
|
"vld1.8 {d0[6]}, [%[source]]\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vst1.8 {d0}, [%[destination]:64]!\n"
|
|
|
|
// Aggregator Reduction.
|
|
"vmov.32 d1[0], %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[additive_offset]\n"
|
|
"vpaddl.u16 q2, q2\n"
|
|
"vpadd.u32 d6, d4, d5\n"
|
|
"vpadd.u32 d8, d6, d6\n"
|
|
"vmul.i32 q4, q4, d1[0]\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vst1.32 {d8[0]}, [%[destination]]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
|
|
[destination] "+r"(destination), [source] "+r"(source)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "cc", "memory");
|
|
}
|
|
|
|
void zip_2x8(const std::uint8_t* source, std::int32_t count,
|
|
std::int32_t stride, std::uint8_t* destination,
|
|
std::int32_t multiplicative_offset, std::int32_t additive_offset) {
|
|
asm volatile(
|
|
"add r0, %[source], %[stride]\n"
|
|
"vmov.i16 q2, #0\n"
|
|
"vmov.i16 q3, #0\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
// Load Aggregate Store.
|
|
"vld1.8 {d0}, [%[source]]!\n"
|
|
"vld1.8 {d1}, [r0]!\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vaddw.u8 q3, q3, d1\n"
|
|
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
|
|
// Aggregator Reduction.
|
|
"vmov.32 d2[0], %[multiplicative_offset]\n"
|
|
"vdup.32 q4, %[additive_offset]\n"
|
|
"vpaddl.u16 q2, q2\n"
|
|
"vpaddl.u16 q3, q3\n"
|
|
"vpadd.u32 d3, d4, d5\n"
|
|
"vpadd.u32 d10, d6, d7\n"
|
|
"vpadd.u32 d12, d3, d10\n"
|
|
"vmul.i32 q6, q6, d2[0]\n"
|
|
"vadd.i32 q6, q6, q4\n"
|
|
"vst1.32 {d12}, [%[destination]:64]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
|
|
[destination] "+r"(destination), [source] "+r"(source)
|
|
:
|
|
: "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
|
|
"d12", "d13", "cc", "memory");
|
|
}
|
|
|
|
void zip_2x8_1(const std::uint8_t* source, std::int32_t count,
|
|
std::int32_t stride, std::uint8_t* destination,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t additive_offset) {
|
|
asm volatile(
|
|
"add r0, %[source], %[stride]\n"
|
|
"sub %[count], %[count], #1\n"
|
|
"vmov.i16 q2, #0\n"
|
|
"vmov.i16 q3, #0\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
// Load Aggregate Store.
|
|
"vld1.8 {d0}, [%[source]]!\n"
|
|
"vld1.8 {d1}, [r0]!\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vaddw.u8 q3, q3, d1\n"
|
|
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
|
|
// Leftover Load Aggregate Store.
|
|
"vmov.i8 d0, #0\n"
|
|
"vmov.i8 d1, #0\n"
|
|
"vld1.8 {d0[0]}, [%[source]]\n"
|
|
"vld1.8 {d1[0]}, [r0]\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vaddw.u8 q3, q3, d1\n"
|
|
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
|
|
|
|
// Aggregator Reduction.
|
|
"vmov.32 d2[0], %[multiplicative_offset]\n"
|
|
"vdup.32 q4, %[additive_offset]\n"
|
|
"vpaddl.u16 q2, q2\n"
|
|
"vpaddl.u16 q3, q3\n"
|
|
"vpadd.u32 d3, d4, d5\n"
|
|
"vpadd.u32 d10, d6, d7\n"
|
|
"vpadd.u32 d12, d3, d10\n"
|
|
"vmul.i32 q6, q6, d2[0]\n"
|
|
"vadd.i32 q6, q6, q4\n"
|
|
"vst1.32 {d12}, [%[destination]:64]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
|
|
[destination] "+r"(destination), [source] "+r"(source)
|
|
:
|
|
: "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
|
|
"d12", "d13", "cc", "memory");
|
|
}
|
|
|
|
void zip_2x8_2(const std::uint8_t* source, std::int32_t count,
|
|
std::int32_t stride, std::uint8_t* destination,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t additive_offset) {
|
|
asm volatile(
|
|
"add r0, %[source], %[stride]\n"
|
|
"sub %[count], %[count], #2\n"
|
|
"vmov.i16 q2, #0\n"
|
|
"vmov.i16 q3, #0\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
// Load Aggregate Store.
|
|
"vld1.8 {d0}, [%[source]]!\n"
|
|
"vld1.8 {d1}, [r0]!\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vaddw.u8 q3, q3, d1\n"
|
|
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
|
|
// Leftover Load Aggregate Store.
|
|
"vmov.i8 d0, #0\n"
|
|
"vmov.i8 d1, #0\n"
|
|
"vld1.16 {d0[0]}, [%[source]]\n"
|
|
"vld1.16 {d1[0]}, [r0]\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vaddw.u8 q3, q3, d1\n"
|
|
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
|
|
|
|
// Aggregator Reduction.
|
|
"vmov.32 d2[0], %[multiplicative_offset]\n"
|
|
"vdup.32 q4, %[additive_offset]\n"
|
|
"vpaddl.u16 q2, q2\n"
|
|
"vpaddl.u16 q3, q3\n"
|
|
"vpadd.u32 d3, d4, d5\n"
|
|
"vpadd.u32 d10, d6, d7\n"
|
|
"vpadd.u32 d12, d3, d10\n"
|
|
"vmul.i32 q6, q6, d2[0]\n"
|
|
"vadd.i32 q6, q6, q4\n"
|
|
"vst1.32 {d12}, [%[destination]:64]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
|
|
[destination] "+r"(destination), [source] "+r"(source)
|
|
:
|
|
: "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
|
|
"d12", "d13", "cc", "memory");
|
|
}
|
|
|
|
void zip_2x8_3(const std::uint8_t* source, std::int32_t count,
|
|
std::int32_t stride, std::uint8_t* destination,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t additive_offset) {
|
|
asm volatile(
|
|
"add r0, %[source], %[stride]\n"
|
|
"sub %[count], %[count], #3\n"
|
|
"vmov.i16 q2, #0\n"
|
|
"vmov.i16 q3, #0\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
// Load Aggregate Store.
|
|
"vld1.8 {d0}, [%[source]]!\n"
|
|
"vld1.8 {d1}, [r0]!\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vaddw.u8 q3, q3, d1\n"
|
|
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
|
|
// Leftover Load Aggregate Store.
|
|
"vmov.i8 d0, #0\n"
|
|
"vmov.i8 d1, #0\n"
|
|
"vld1.16 {d0[0]}, [%[source]]!\n"
|
|
"vld1.16 {d1[0]}, [r0]!\n"
|
|
"vld1.8 {d0[2]}, [%[source]]\n"
|
|
"vld1.8 {d1[2]}, [r0]\n"
|
|
"vaddw.u8 q2, q2, d0\n"
|
|
"vaddw.u8 q3, q3, d1\n"
|
|
"vst1.8 {d0, d1}, [%[destination]:64]!\n"
|
|
|
|
// Aggregator Reduction.
|
|
"vmov.32 d2[0], %[multiplicative_offset]\n"
|
|
"vdup.32 q4, %[additive_offset]\n"
|
|
"vpaddl.u16 q2, q2\n"
|
|
"vpaddl.u16 q3, q3\n"
|
|
"vpadd.u32 d3, d4, d5\n"
|
|
"vpadd.u32 d10, d6, d7\n"
|
|
"vpadd.u32 d12, d3, d10\n"
|
|
"vmul.i32 q6, q6, d2[0]\n"
|
|
"vadd.i32 q6, q6, q4\n"
|
|
"vst1.32 {d12}, [%[destination]:64]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
|
|
[destination] "+r"(destination), [source] "+r"(source)
|
|
:
|
|
: "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
|
|
"d12", "d13", "cc", "memory");
|
|
}
|
|
|
|
void zip_2x8_4(const std::uint8_t* source, std::int32_t count,
               std::int32_t stride, std::uint8_t* destination,
               std::int32_t multiplicative_offset,
               std::int32_t additive_offset) {
  asm volatile(
      "add r0, %[source], %[stride]\n"
      "sub %[count], %[count], #4\n"
      "vmov.i16 q2, #0\n"
      "vmov.i16 q3, #0\n"

      "1:"
      "subs %[count], %[count], #8\n"

      // Load Aggregate Store.
      "vld1.8 {d0}, [%[source]]!\n"
      "vld1.8 {d1}, [r0]!\n"
      "vaddw.u8 q2, q2, d0\n"
      "vaddw.u8 q3, q3, d1\n"
      "vst1.8 {d0, d1}, [%[destination]:64]!\n"

      "bne 1b\n"

      // Leftover Load Aggregate Store.
      "vmov.i8 d0, #0\n"
      "vmov.i8 d1, #0\n"
      "vld1.32 {d0[0]}, [%[source]]\n"
      "vld1.32 {d1[0]}, [r0]\n"
      "vaddw.u8 q2, q2, d0\n"
      "vaddw.u8 q3, q3, d1\n"
      "vst1.8 {d0, d1}, [%[destination]:64]!\n"

      // Aggregator Reduction.
      "vmov.32 d2[0], %[multiplicative_offset]\n"
      "vdup.32 q4, %[additive_offset]\n"
      "vpaddl.u16 q2, q2\n"
      "vpaddl.u16 q3, q3\n"
      "vpadd.u32 d3, d4, d5\n"
      "vpadd.u32 d10, d6, d7\n"
      "vpadd.u32 d12, d3, d10\n"
      "vmul.i32 q6, q6, d2[0]\n"
      "vadd.i32 q6, q6, q4\n"
      "vst1.32 {d12}, [%[destination]:64]\n"
      : [count] "+r"(count),
        [multiplicative_offset] "+r"(multiplicative_offset),
        [additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
        [destination] "+r"(destination), [source] "+r"(source)
      :
      : "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
        "d12", "d13", "cc", "memory");
}

void zip_2x8_5(const std::uint8_t* source, std::int32_t count,
               std::int32_t stride, std::uint8_t* destination,
               std::int32_t multiplicative_offset,
               std::int32_t additive_offset) {
  asm volatile(
      "add r0, %[source], %[stride]\n"
      "sub %[count], %[count], #5\n"
      "vmov.i16 q2, #0\n"
      "vmov.i16 q3, #0\n"

      "1:"
      "subs %[count], %[count], #8\n"

      // Load Aggregate Store.
      "vld1.8 {d0}, [%[source]]!\n"
      "vld1.8 {d1}, [r0]!\n"
      "vaddw.u8 q2, q2, d0\n"
      "vaddw.u8 q3, q3, d1\n"
      "vst1.8 {d0, d1}, [%[destination]:64]!\n"

      "bne 1b\n"

      // Leftover Load Aggregate Store.
      "vmov.i8 d0, #0\n"
      "vmov.i8 d1, #0\n"
      "vld1.32 {d0[0]}, [%[source]]!\n"
      "vld1.32 {d1[0]}, [r0]!\n"
      "vld1.8 {d0[4]}, [%[source]]\n"
      "vld1.8 {d1[4]}, [r0]\n"
      "vaddw.u8 q2, q2, d0\n"
      "vaddw.u8 q3, q3, d1\n"
      "vst1.8 {d0, d1}, [%[destination]:64]!\n"

      // Aggregator Reduction.
      "vmov.32 d2[0], %[multiplicative_offset]\n"
      "vdup.32 q4, %[additive_offset]\n"
      "vpaddl.u16 q2, q2\n"
      "vpaddl.u16 q3, q3\n"
      "vpadd.u32 d3, d4, d5\n"
      "vpadd.u32 d10, d6, d7\n"
      "vpadd.u32 d12, d3, d10\n"
      "vmul.i32 q6, q6, d2[0]\n"
      "vadd.i32 q6, q6, q4\n"
      "vst1.32 {d12}, [%[destination]:64]\n"
      : [count] "+r"(count),
        [multiplicative_offset] "+r"(multiplicative_offset),
        [additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
        [destination] "+r"(destination), [source] "+r"(source)
      :
      : "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
        "d12", "d13", "cc", "memory");
}

void zip_2x8_6(const std::uint8_t* source, std::int32_t count,
               std::int32_t stride, std::uint8_t* destination,
               std::int32_t multiplicative_offset,
               std::int32_t additive_offset) {
  asm volatile(
      "add r0, %[source], %[stride]\n"
      "sub %[count], %[count], #6\n"
      "vmov.i16 q2, #0\n"
      "vmov.i16 q3, #0\n"

      "1:"
      "subs %[count], %[count], #8\n"

      // Load Aggregate Store.
      "vld1.8 {d0}, [%[source]]!\n"
      "vld1.8 {d1}, [r0]!\n"
      "vaddw.u8 q2, q2, d0\n"
      "vaddw.u8 q3, q3, d1\n"
      "vst1.8 {d0, d1}, [%[destination]:64]!\n"

      "bne 1b\n"

      // Leftover Load Aggregate Store.
      "vmov.i8 d0, #0\n"
      "vmov.i8 d1, #0\n"
      "vld1.32 {d0[0]}, [%[source]]!\n"
      "vld1.32 {d1[0]}, [r0]!\n"
      "vld1.16 {d0[2]}, [%[source]]\n"
      "vld1.16 {d1[2]}, [r0]\n"
      "vaddw.u8 q2, q2, d0\n"
      "vaddw.u8 q3, q3, d1\n"
      "vst1.8 {d0, d1}, [%[destination]:64]!\n"

      // Aggregator Reduction.
      "vmov.32 d2[0], %[multiplicative_offset]\n"
      "vdup.32 q4, %[additive_offset]\n"
      "vpaddl.u16 q2, q2\n"
      "vpaddl.u16 q3, q3\n"
      "vpadd.u32 d3, d4, d5\n"
      "vpadd.u32 d10, d6, d7\n"
      "vpadd.u32 d12, d3, d10\n"
      "vmul.i32 q6, q6, d2[0]\n"
      "vadd.i32 q6, q6, q4\n"
      "vst1.32 {d12}, [%[destination]:64]\n"
      : [count] "+r"(count),
        [multiplicative_offset] "+r"(multiplicative_offset),
        [additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
        [destination] "+r"(destination), [source] "+r"(source)
      :
      : "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
        "d12", "d13", "cc", "memory");
}

void zip_2x8_7(const std::uint8_t* source, std::int32_t count,
               std::int32_t stride, std::uint8_t* destination,
               std::int32_t multiplicative_offset,
               std::int32_t additive_offset) {
  asm volatile(
      "add r0, %[source], %[stride]\n"
      "sub %[count], %[count], #7\n"
      "vmov.i16 q2, #0\n"
      "vmov.i16 q3, #0\n"

      "1:"
      "subs %[count], %[count], #8\n"

      // Load Aggregate Store.
      "vld1.8 {d0}, [%[source]]!\n"
      "vld1.8 {d1}, [r0]!\n"
      "vaddw.u8 q2, q2, d0\n"
      "vaddw.u8 q3, q3, d1\n"
      "vst1.8 {d0, d1}, [%[destination]:64]!\n"

      "bne 1b\n"

      // Leftover Load Aggregate Store.
      "vmov.i8 d0, #0\n"
      "vmov.i8 d1, #0\n"
      "vld1.32 {d0[0]}, [%[source]]!\n"
      "vld1.32 {d1[0]}, [r0]!\n"
      "vld1.16 {d0[2]}, [%[source]]!\n"
      "vld1.16 {d1[2]}, [r0]!\n"
      "vld1.8 {d0[6]}, [%[source]]\n"
      "vld1.8 {d1[6]}, [r0]\n"
      "vaddw.u8 q2, q2, d0\n"
      "vaddw.u8 q3, q3, d1\n"
      "vst1.8 {d0, d1}, [%[destination]:64]!\n"

      // Aggregator Reduction.
      "vmov.32 d2[0], %[multiplicative_offset]\n"
      "vdup.32 q4, %[additive_offset]\n"
      "vpaddl.u16 q2, q2\n"
      "vpaddl.u16 q3, q3\n"
      "vpadd.u32 d3, d4, d5\n"
      "vpadd.u32 d10, d6, d7\n"
      "vpadd.u32 d12, d3, d10\n"
      "vmul.i32 q6, q6, d2[0]\n"
      "vadd.i32 q6, q6, q4\n"
      "vst1.32 {d12}, [%[destination]:64]\n"
      : [count] "+r"(count),
        [multiplicative_offset] "+r"(multiplicative_offset),
        [additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
        [destination] "+r"(destination), [source] "+r"(source)
      :
      : "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
        "d12", "d13", "cc", "memory");
}

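// The zip_3x8* kernels are the three-row counterparts: three source rows are
// packed per 8-column step, and three adjusted row sums (two in d16, one in
// d17[0]) are appended after the packed block.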
void zip_3x8(const std::uint8_t* source, std::int32_t count,
             std::int32_t stride, std::uint8_t* destination,
             std::int32_t multiplicative_offset, std::int32_t additive_offset) {
  asm volatile(
      "add r0, %[source], %[stride]\n"
      "add r1, r0, %[stride]\n"
      "vmov.i16 q2, #0\n"
      "vmov.i16 q3, #0\n"
      "vmov.i16 q4, #0\n"

      "1:"
      "subs %[count], %[count], #8\n"

      // Load Aggregate Store.
      "vld1.8 {d0}, [%[source]]!\n"
      "vld1.8 {d1}, [r0]!\n"
      "vld1.8 {d2}, [r1]!\n"
      "vaddw.u8 q2, q2, d0\n"
      "vaddw.u8 q3, q3, d1\n"
      "vaddw.u8 q4, q4, d2\n"
      "vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"

      "bne 1b\n"

      // Aggregator Reduction.
      "vmov.32 d3[0], %[multiplicative_offset]\n"
      "vdup.32 q5, %[additive_offset]\n"
      "vpaddl.u16 q2, q2\n"
      "vpaddl.u16 q3, q3\n"
      "vpaddl.u16 q4, q4\n"
      "vpadd.u32 d12, d4, d5\n"
      "vpadd.u32 d13, d6, d7\n"
      "vpadd.u32 d14, d8, d9\n"
      "vpadd.u32 d16, d12, d13\n"
      "vpadd.u32 d17, d14, d14\n"
      "vmul.i32 q8, q8, d3[0]\n"
      "vadd.i32 q8, q8, q5\n"
      "vst1.32 {d16}, [%[destination]:64]!\n"
      "vst1.32 {d17[0]}, [%[destination]]\n"
      : [count] "+r"(count),
        [multiplicative_offset] "+r"(multiplicative_offset),
        [additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
        [destination] "+r"(destination), [source] "+r"(source)
      :
      : "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
        "d10", "d11", "d12", "d13", "d14", "d16", "d17", "cc", "memory");
}

void zip_3x8_1(const std::uint8_t* source, std::int32_t count,
               std::int32_t stride, std::uint8_t* destination,
               std::int32_t multiplicative_offset,
               std::int32_t additive_offset) {
  asm volatile(
      "add r0, %[source], %[stride]\n"
      "add r1, r0, %[stride]\n"
      "sub %[count], %[count], #1\n"
      "vmov.i16 q2, #0\n"
      "vmov.i16 q3, #0\n"
      "vmov.i16 q4, #0\n"

      "1:"
      "subs %[count], %[count], #8\n"

      // Load Aggregate Store.
      "vld1.8 {d0}, [%[source]]!\n"
      "vld1.8 {d1}, [r0]!\n"
      "vld1.8 {d2}, [r1]!\n"
      "vaddw.u8 q2, q2, d0\n"
      "vaddw.u8 q3, q3, d1\n"
      "vaddw.u8 q4, q4, d2\n"
      "vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"

      "bne 1b\n"

      // Leftover Load Aggregate Store.
      "vmov.i8 d0, #0\n"
      "vmov.i8 d1, #0\n"
      "vmov.i8 d2, #0\n"
      "vld1.8 {d0[0]}, [%[source]]\n"
      "vld1.8 {d1[0]}, [r0]\n"
      "vld1.8 {d2[0]}, [r1]\n"
      "vaddw.u8 q2, q2, d0\n"
      "vaddw.u8 q3, q3, d1\n"
      "vaddw.u8 q4, q4, d2\n"
      "vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"

      // Aggregator Reduction.
      "vmov.32 d3[0], %[multiplicative_offset]\n"
      "vdup.32 q5, %[additive_offset]\n"
      "vpaddl.u16 q2, q2\n"
      "vpaddl.u16 q3, q3\n"
      "vpaddl.u16 q4, q4\n"
      "vpadd.u32 d12, d4, d5\n"
      "vpadd.u32 d13, d6, d7\n"
      "vpadd.u32 d14, d8, d9\n"
      "vpadd.u32 d16, d12, d13\n"
      "vpadd.u32 d17, d14, d14\n"
      "vmul.i32 q8, q8, d3[0]\n"
      "vadd.i32 q8, q8, q5\n"
      "vst1.32 {d16}, [%[destination]:64]!\n"
      "vst1.32 {d17[0]}, [%[destination]]\n"
      : [count] "+r"(count),
        [multiplicative_offset] "+r"(multiplicative_offset),
        [additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
        [destination] "+r"(destination), [source] "+r"(source)
      :
      : "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
        "d10", "d11", "d12", "d13", "d14", "d16", "d17", "cc", "memory");
}

void zip_3x8_2(const std::uint8_t* source, std::int32_t count,
               std::int32_t stride, std::uint8_t* destination,
               std::int32_t multiplicative_offset,
               std::int32_t additive_offset) {
  asm volatile(
      "add r0, %[source], %[stride]\n"
      "add r1, r0, %[stride]\n"
      "sub %[count], %[count], #2\n"
      "vmov.i16 q2, #0\n"
      "vmov.i16 q3, #0\n"
      "vmov.i16 q4, #0\n"

      "1:"
      "subs %[count], %[count], #8\n"

      // Load Aggregate Store.
      "vld1.8 {d0}, [%[source]]!\n"
      "vld1.8 {d1}, [r0]!\n"
      "vld1.8 {d2}, [r1]!\n"
      "vaddw.u8 q2, q2, d0\n"
      "vaddw.u8 q3, q3, d1\n"
      "vaddw.u8 q4, q4, d2\n"
      "vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"

      "bne 1b\n"

      // Leftover Load Aggregate Store.
      "vmov.i8 d0, #0\n"
      "vmov.i8 d1, #0\n"
      "vmov.i8 d2, #0\n"
      "vld1.16 {d0[0]}, [%[source]]\n"
      "vld1.16 {d1[0]}, [r0]\n"
      "vld1.16 {d2[0]}, [r1]\n"
      "vaddw.u8 q2, q2, d0\n"
      "vaddw.u8 q3, q3, d1\n"
      "vaddw.u8 q4, q4, d2\n"
      "vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"

      // Aggregator Reduction.
      "vmov.32 d3[0], %[multiplicative_offset]\n"
      "vdup.32 q5, %[additive_offset]\n"
      "vpaddl.u16 q2, q2\n"
      "vpaddl.u16 q3, q3\n"
      "vpaddl.u16 q4, q4\n"
      "vpadd.u32 d12, d4, d5\n"
      "vpadd.u32 d13, d6, d7\n"
      "vpadd.u32 d14, d8, d9\n"
      "vpadd.u32 d16, d12, d13\n"
      "vpadd.u32 d17, d14, d14\n"
      "vmul.i32 q8, q8, d3[0]\n"
      "vadd.i32 q8, q8, q5\n"
      "vst1.32 {d16}, [%[destination]:64]!\n"
      "vst1.32 {d17[0]}, [%[destination]]\n"
      : [count] "+r"(count),
        [multiplicative_offset] "+r"(multiplicative_offset),
        [additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
        [destination] "+r"(destination), [source] "+r"(source)
      :
      : "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
        "d10", "d11", "d12", "d13", "d14", "d16", "d17", "cc", "memory");
}

void zip_3x8_3(const std::uint8_t* source, std::int32_t count,
               std::int32_t stride, std::uint8_t* destination,
               std::int32_t multiplicative_offset,
               std::int32_t additive_offset) {
  asm volatile(
      "add r0, %[source], %[stride]\n"
      "add r1, r0, %[stride]\n"
      "sub %[count], %[count], #3\n"
      "vmov.i16 q2, #0\n"
      "vmov.i16 q3, #0\n"
      "vmov.i16 q4, #0\n"

      "1:"
      "subs %[count], %[count], #8\n"

      // Load Aggregate Store.
      "vld1.8 {d0}, [%[source]]!\n"
      "vld1.8 {d1}, [r0]!\n"
      "vld1.8 {d2}, [r1]!\n"
      "vaddw.u8 q2, q2, d0\n"
      "vaddw.u8 q3, q3, d1\n"
      "vaddw.u8 q4, q4, d2\n"
      "vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"

      "bne 1b\n"

      // Leftover Load Aggregate Store.
      "vmov.i8 d0, #0\n"
      "vmov.i8 d1, #0\n"
      "vmov.i8 d2, #0\n"
      "vld1.16 {d0[0]}, [%[source]]!\n"
      "vld1.16 {d1[0]}, [r0]!\n"
      "vld1.16 {d2[0]}, [r1]!\n"
      "vld1.8 {d0[2]}, [%[source]]\n"
      "vld1.8 {d1[2]}, [r0]\n"
      "vld1.8 {d2[2]}, [r1]\n"
      "vaddw.u8 q2, q2, d0\n"
      "vaddw.u8 q3, q3, d1\n"
      "vaddw.u8 q4, q4, d2\n"
      "vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"

      // Aggregator Reduction.
      "vmov.32 d3[0], %[multiplicative_offset]\n"
      "vdup.32 q5, %[additive_offset]\n"
      "vpaddl.u16 q2, q2\n"
      "vpaddl.u16 q3, q3\n"
      "vpaddl.u16 q4, q4\n"
      "vpadd.u32 d12, d4, d5\n"
      "vpadd.u32 d13, d6, d7\n"
      "vpadd.u32 d14, d8, d9\n"
      "vpadd.u32 d16, d12, d13\n"
      "vpadd.u32 d17, d14, d14\n"
      "vmul.i32 q8, q8, d3[0]\n"
      "vadd.i32 q8, q8, q5\n"
      "vst1.32 {d16}, [%[destination]:64]!\n"
      "vst1.32 {d17[0]}, [%[destination]]\n"
      : [count] "+r"(count),
        [multiplicative_offset] "+r"(multiplicative_offset),
        [additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
        [destination] "+r"(destination), [source] "+r"(source)
      :
      : "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
        "d10", "d11", "d12", "d13", "d14", "d16", "d17", "cc", "memory");
}

void zip_3x8_4(const std::uint8_t* source, std::int32_t count,
               std::int32_t stride, std::uint8_t* destination,
               std::int32_t multiplicative_offset,
               std::int32_t additive_offset) {
  asm volatile(
      "add r0, %[source], %[stride]\n"
      "add r1, r0, %[stride]\n"
      "sub %[count], %[count], #4\n"
      "vmov.i16 q2, #0\n"
      "vmov.i16 q3, #0\n"
      "vmov.i16 q4, #0\n"

      "1:"
      "subs %[count], %[count], #8\n"

      // Load Aggregate Store.
      "vld1.8 {d0}, [%[source]]!\n"
      "vld1.8 {d1}, [r0]!\n"
      "vld1.8 {d2}, [r1]!\n"
      "vaddw.u8 q2, q2, d0\n"
      "vaddw.u8 q3, q3, d1\n"
      "vaddw.u8 q4, q4, d2\n"
      "vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"

      "bne 1b\n"

      // Leftover Load Aggregate Store.
      "vmov.i8 d0, #0\n"
      "vmov.i8 d1, #0\n"
      "vmov.i8 d2, #0\n"
      "vld1.32 {d0[0]}, [%[source]]\n"
      "vld1.32 {d1[0]}, [r0]\n"
      "vld1.32 {d2[0]}, [r1]\n"
      "vaddw.u8 q2, q2, d0\n"
      "vaddw.u8 q3, q3, d1\n"
      "vaddw.u8 q4, q4, d2\n"
      "vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"

      // Aggregator Reduction.
      "vmov.32 d3[0], %[multiplicative_offset]\n"
      "vdup.32 q5, %[additive_offset]\n"
      "vpaddl.u16 q2, q2\n"
      "vpaddl.u16 q3, q3\n"
      "vpaddl.u16 q4, q4\n"
      "vpadd.u32 d12, d4, d5\n"
      "vpadd.u32 d13, d6, d7\n"
      "vpadd.u32 d14, d8, d9\n"
      "vpadd.u32 d16, d12, d13\n"
      "vpadd.u32 d17, d14, d14\n"
      "vmul.i32 q8, q8, d3[0]\n"
      "vadd.i32 q8, q8, q5\n"
      "vst1.32 {d16}, [%[destination]:64]!\n"
      "vst1.32 {d17[0]}, [%[destination]]\n"
      : [count] "+r"(count),
        [multiplicative_offset] "+r"(multiplicative_offset),
        [additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
        [destination] "+r"(destination), [source] "+r"(source)
      :
      : "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
        "d10", "d11", "d12", "d13", "d14", "d16", "d17", "cc", "memory");
}

void zip_3x8_5(const std::uint8_t* source, std::int32_t count,
               std::int32_t stride, std::uint8_t* destination,
               std::int32_t multiplicative_offset,
               std::int32_t additive_offset) {
  asm volatile(
      "add r0, %[source], %[stride]\n"
      "add r1, r0, %[stride]\n"
      "sub %[count], %[count], #5\n"
      "vmov.i16 q2, #0\n"
      "vmov.i16 q3, #0\n"
      "vmov.i16 q4, #0\n"

      "1:"
      "subs %[count], %[count], #8\n"

      // Load Aggregate Store.
      "vld1.8 {d0}, [%[source]]!\n"
      "vld1.8 {d1}, [r0]!\n"
      "vld1.8 {d2}, [r1]!\n"
      "vaddw.u8 q2, q2, d0\n"
      "vaddw.u8 q3, q3, d1\n"
      "vaddw.u8 q4, q4, d2\n"
      "vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"

      "bne 1b\n"

      // Leftover Load Aggregate Store.
      "vmov.i8 d0, #0\n"
      "vmov.i8 d1, #0\n"
      "vmov.i8 d2, #0\n"
      "vld1.32 {d0[0]}, [%[source]]!\n"
      "vld1.32 {d1[0]}, [r0]!\n"
      "vld1.32 {d2[0]}, [r1]!\n"
      "vld1.8 {d0[4]}, [%[source]]\n"
      "vld1.8 {d1[4]}, [r0]\n"
      "vld1.8 {d2[4]}, [r1]\n"
      "vaddw.u8 q2, q2, d0\n"
      "vaddw.u8 q3, q3, d1\n"
      "vaddw.u8 q4, q4, d2\n"
      "vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"

      // Aggregator Reduction.
      "vmov.32 d3[0], %[multiplicative_offset]\n"
      "vdup.32 q5, %[additive_offset]\n"
      "vpaddl.u16 q2, q2\n"
      "vpaddl.u16 q3, q3\n"
      "vpaddl.u16 q4, q4\n"
      "vpadd.u32 d12, d4, d5\n"
      "vpadd.u32 d13, d6, d7\n"
      "vpadd.u32 d14, d8, d9\n"
      "vpadd.u32 d16, d12, d13\n"
      "vpadd.u32 d17, d14, d14\n"
      "vmul.i32 q8, q8, d3[0]\n"
      "vadd.i32 q8, q8, q5\n"
      "vst1.32 {d16}, [%[destination]:64]!\n"
      "vst1.32 {d17[0]}, [%[destination]]\n"
      : [count] "+r"(count),
        [multiplicative_offset] "+r"(multiplicative_offset),
        [additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
        [destination] "+r"(destination), [source] "+r"(source)
      :
      : "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
        "d10", "d11", "d12", "d13", "d14", "d16", "d17", "cc", "memory");
}

void zip_3x8_6(const std::uint8_t* source, std::int32_t count,
               std::int32_t stride, std::uint8_t* destination,
               std::int32_t multiplicative_offset,
               std::int32_t additive_offset) {
  asm volatile(
      "add r0, %[source], %[stride]\n"
      "add r1, r0, %[stride]\n"
      "sub %[count], %[count], #6\n"
      "vmov.i16 q2, #0\n"
      "vmov.i16 q3, #0\n"
      "vmov.i16 q4, #0\n"

      "1:"
      "subs %[count], %[count], #8\n"

      // Load Aggregate Store.
      "vld1.8 {d0}, [%[source]]!\n"
      "vld1.8 {d1}, [r0]!\n"
      "vld1.8 {d2}, [r1]!\n"
      "vaddw.u8 q2, q2, d0\n"
      "vaddw.u8 q3, q3, d1\n"
      "vaddw.u8 q4, q4, d2\n"
      "vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"

      "bne 1b\n"

      // Leftover Load Aggregate Store.
      "vmov.i8 d0, #0\n"
      "vmov.i8 d1, #0\n"
      "vmov.i8 d2, #0\n"
      "vld1.32 {d0[0]}, [%[source]]!\n"
      "vld1.32 {d1[0]}, [r0]!\n"
      "vld1.32 {d2[0]}, [r1]!\n"
      "vld1.16 {d0[2]}, [%[source]]\n"
      "vld1.16 {d1[2]}, [r0]\n"
      "vld1.16 {d2[2]}, [r1]\n"
      "vaddw.u8 q2, q2, d0\n"
      "vaddw.u8 q3, q3, d1\n"
      "vaddw.u8 q4, q4, d2\n"
      "vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"

      // Aggregator Reduction.
      "vmov.32 d3[0], %[multiplicative_offset]\n"
      "vdup.32 q5, %[additive_offset]\n"
      "vpaddl.u16 q2, q2\n"
      "vpaddl.u16 q3, q3\n"
      "vpaddl.u16 q4, q4\n"
      "vpadd.u32 d12, d4, d5\n"
      "vpadd.u32 d13, d6, d7\n"
      "vpadd.u32 d14, d8, d9\n"
      "vpadd.u32 d16, d12, d13\n"
      "vpadd.u32 d17, d14, d14\n"
      "vmul.i32 q8, q8, d3[0]\n"
      "vadd.i32 q8, q8, q5\n"
      "vst1.32 {d16}, [%[destination]:64]!\n"
      "vst1.32 {d17[0]}, [%[destination]]\n"
      : [count] "+r"(count),
        [multiplicative_offset] "+r"(multiplicative_offset),
        [additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
        [destination] "+r"(destination), [source] "+r"(source)
      :
      : "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
        "d10", "d11", "d12", "d13", "d14", "d16", "d17", "cc", "memory");
}

void zip_3x8_7(const std::uint8_t* source, std::int32_t count,
               std::int32_t stride, std::uint8_t* destination,
               std::int32_t multiplicative_offset,
               std::int32_t additive_offset) {
  asm volatile(
      "add r0, %[source], %[stride]\n"
      "add r1, r0, %[stride]\n"
      "sub %[count], %[count], #7\n"
      "vmov.i16 q2, #0\n"
      "vmov.i16 q3, #0\n"
      "vmov.i16 q4, #0\n"

      "1:"
      "subs %[count], %[count], #8\n"

      // Load Aggregate Store.
      "vld1.8 {d0}, [%[source]]!\n"
      "vld1.8 {d1}, [r0]!\n"
      "vld1.8 {d2}, [r1]!\n"
      "vaddw.u8 q2, q2, d0\n"
      "vaddw.u8 q3, q3, d1\n"
      "vaddw.u8 q4, q4, d2\n"
      "vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"

      "bne 1b\n"

      // Leftover Load Aggregate Store.
      "vmov.i8 d0, #0\n"
      "vmov.i8 d1, #0\n"
      "vmov.i8 d2, #0\n"
      "vld1.32 {d0[0]}, [%[source]]!\n"
      "vld1.32 {d1[0]}, [r0]!\n"
      "vld1.32 {d2[0]}, [r1]!\n"
      "vld1.16 {d0[2]}, [%[source]]!\n"
      "vld1.16 {d1[2]}, [r0]!\n"
      "vld1.16 {d2[2]}, [r1]!\n"
      "vld1.8 {d0[6]}, [%[source]]\n"
      "vld1.8 {d1[6]}, [r0]\n"
      "vld1.8 {d2[6]}, [r1]\n"
      "vaddw.u8 q2, q2, d0\n"
      "vaddw.u8 q3, q3, d1\n"
      "vaddw.u8 q4, q4, d2\n"
      "vst1.8 {d0, d1, d2}, [%[destination]:64]!\n"

      // Aggregator Reduction.
      "vmov.32 d3[0], %[multiplicative_offset]\n"
      "vdup.32 q5, %[additive_offset]\n"
      "vpaddl.u16 q2, q2\n"
      "vpaddl.u16 q3, q3\n"
      "vpaddl.u16 q4, q4\n"
      "vpadd.u32 d12, d4, d5\n"
      "vpadd.u32 d13, d6, d7\n"
      "vpadd.u32 d14, d8, d9\n"
      "vpadd.u32 d16, d12, d13\n"
      "vpadd.u32 d17, d14, d14\n"
      "vmul.i32 q8, q8, d3[0]\n"
      "vadd.i32 q8, q8, q5\n"
      "vst1.32 {d16}, [%[destination]:64]!\n"
      "vst1.32 {d17[0]}, [%[destination]]\n"
      : [count] "+r"(count),
        [multiplicative_offset] "+r"(multiplicative_offset),
        [additive_offset] "+r"(additive_offset), [stride] "+r"(stride),
        [destination] "+r"(destination), [source] "+r"(source)
      :
      : "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
        "d10", "d11", "d12", "d13", "d14", "d16", "d17", "cc", "memory");
}

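// The mul_Mx8_Nx8_int32_rhsadd kernels compute the dot products of M packed
// lhs rows against N packed rhs rows over the shared depth, then add the int32
// terms that the zip kernels appended after the packed rhs block, and store
// the MxN results row by row, advancing the result pointer by result_stride
// between rows.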
inline void mul_1x8_1x8_int32_rhsadd(const std::uint8_t* lhs,
                                     const std::uint8_t* rhs,
                                     std::int32_t count, std::int32_t* result,
                                     std::int32_t result_stride) {
  asm volatile(
      // Clear aggregators.
      "vmov.i32 q0, #0\n"

      "pld [%[lhs]]\n"
      "pld [%[rhs]]\n"
      // General NxM lanes loop.
      "1:"

      // Subtract counter.
      "subs %[count], %[count], #8\n"

      "vld1.8 {d2}, [%[lhs]:64]!\n"
      "vld1.8 {d3}, [%[rhs]:64]!\n"
      "pld [%[lhs], #64]\n"
      "pld [%[rhs], #64]\n"
      "vmull.u8 q2, d3, d2\n"
      "vpadal.u16 q0, q2\n"

      // Loop break.
      "bne 1b\n"

      "vld1.32 {d8}, [%[rhs]:64]\n"

      // Horizontal reduce aggregators.
      "vpadd.u32 d0, d0, d1\n"

      // Reduce rows.
      "vpadd.u32 d0, d0, d0\n"

      // Add rhs offset to aggregated rows.
      "vadd.s32 d0, d0, d8\n"

      // Store reduced rows.
      "vst1.32 {d0[0]}, [%[result]], %[result_stride]\n"
      : [count] "+r"(count), [result_stride] "+r"(result_stride),
        [rhs] "+r"(rhs), [lhs] "+r"(lhs), [result] "+r"(result)
      :
      : "d0", "d1", "d2", "d3", "d4", "d5", "d8", "cc", "memory");
}

inline void mul_1x8_2x8_int32_rhsadd(const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs,
|
|
std::int32_t count, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
asm volatile(
|
|
// Clear aggregators.
|
|
"vmov.i32 q0, #0\n"
|
|
"vmov.i32 q1, #0\n"
|
|
|
|
"pld [%[lhs]]\n"
|
|
"pld [%[rhs]]\n"
|
|
// General NxM lanes loop.
|
|
"1:"
|
|
|
|
// Subtract counter.
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
"vld1.8 {d4}, [%[lhs]:64]!\n"
|
|
"vld1.8 {d5, d6}, [%[rhs]:64]!\n"
|
|
"pld [%[lhs], #64]\n"
|
|
"pld [%[rhs], #64]\n"
|
|
"vmull.u8 q4, d5, d4\n"
|
|
"vmull.u8 q5, d6, d4\n"
|
|
"vpadal.u16 q0, q4\n"
|
|
"vpadal.u16 q1, q5\n"
|
|
|
|
// Loop break.
|
|
"bne 1b\n"
|
|
|
|
"vld1.32 {d8}, [%[rhs]:64]\n"
|
|
|
|
// Horizontal reduce aggregators.
|
|
"vpadd.u32 d0, d0, d1\n"
|
|
"vpadd.u32 d2, d2, d3\n"
|
|
|
|
// Reduce rows.
|
|
"vpadd.u32 d0, d0, d2\n"
|
|
|
|
// Add rhs offset to aggregated rows.
|
|
"vadd.s32 d0, d0, d8\n"
|
|
|
|
// Store reduced rows.
|
|
"vst1.32 {d0}, [%[result]], %[result_stride]\n"
|
|
: [count] "+r"(count), [result_stride] "+r"(result_stride),
|
|
[rhs] "+r"(rhs), [lhs] "+r"(lhs), [result] "+r"(result)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "d10", "d11",
|
|
"cc", "memory");
|
|
}
|
|
|
|
inline void mul_1x8_3x8_int32_rhsadd(const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs,
|
|
std::int32_t count, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
asm volatile(
|
|
// Clear aggregators.
|
|
"vmov.i32 q0, #0\n"
|
|
"vmov.i32 q1, #0\n"
|
|
"vmov.i32 q2, #0\n"
|
|
|
|
"pld [%[lhs]]\n"
|
|
"pld [%[rhs]]\n"
|
|
// General NxM lanes loop.
|
|
"1:"
|
|
|
|
// Subtract counter.
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
"vld1.8 {d6}, [%[lhs]:64]!\n"
|
|
"vld1.8 {d7, d8, d9}, [%[rhs]:64]!\n"
|
|
"pld [%[lhs], #64]\n"
|
|
"pld [%[rhs], #64]\n"
|
|
"vmull.u8 q5, d7, d6\n"
|
|
"vmull.u8 q6, d8, d6\n"
|
|
"vmull.u8 q7, d9, d6\n"
|
|
"vpadal.u16 q0, q5\n"
|
|
"vpadal.u16 q1, q6\n"
|
|
"vpadal.u16 q2, q7\n"
|
|
|
|
// Loop break.
|
|
"bne 1b\n"
|
|
|
|
"vld1.32 {q4}, [%[rhs]:64]\n"
|
|
|
|
// Change stride because storing in two ops.
|
|
"sub %[result_stride], %[result_stride], #8\n"
|
|
|
|
// Horizontal reduce aggregators.
|
|
"vpadd.u32 d0, d0, d1\n"
|
|
"vpadd.u32 d2, d2, d3\n"
|
|
"vpadd.u32 d4, d4, d5\n"
|
|
|
|
// Reduce rows.
|
|
"vpadd.u32 d0, d0, d2\n"
|
|
"vpadd.u32 d1, d4, d4\n"
|
|
|
|
// Add rhs offset to aggregated rows.
|
|
"vadd.s32 q0, q0, q4\n"
|
|
|
|
// Store reduced rows.
|
|
"vst1.32 {d0}, [%[result]]!\n"
|
|
"vst1.32 {d1[0]}, [%[result]], %[result_stride]\n"
|
|
|
|
: [count] "+r"(count), [result_stride] "+r"(result_stride),
|
|
[rhs] "+r"(rhs), [lhs] "+r"(lhs), [result] "+r"(result)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
|
|
"d11", "d12", "d13", "d14", "d15", "cc", "memory");
|
|
}
|
|
|
|
inline void mul_2x8_1x8_int32_rhsadd(const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs,
|
|
std::int32_t count, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
asm volatile(
|
|
// Clear aggregators.
|
|
"vmov.i32 q0, #0\n"
|
|
"vmov.i32 q1, #0\n"
|
|
|
|
"pld [%[lhs]]\n"
|
|
"pld [%[rhs]]\n"
|
|
// General NxM lanes loop.
|
|
"1:"
|
|
|
|
// Subtract counter.
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
"vld1.8 {d4, d5}, [%[lhs]:64]!\n"
|
|
"vld1.8 {d6}, [%[rhs]:64]!\n"
|
|
"pld [%[lhs], #64]\n"
|
|
"pld [%[rhs], #64]\n"
|
|
"vmull.u8 q4, d6, d4\n"
|
|
"vmull.u8 q5, d6, d5\n"
|
|
"vpadal.u16 q0, q4\n"
|
|
"vpadal.u16 q1, q5\n"
|
|
|
|
// Loop break.
|
|
"bne 1b\n"
|
|
|
|
"vld1.32 {d8}, [%[rhs]:64]\n"
|
|
|
|
// Horizontal reduce aggregators.
|
|
"vpadd.u32 d0, d0, d1\n"
|
|
"vpadd.u32 d2, d2, d3\n"
|
|
|
|
// Reduce rows.
|
|
"vpadd.u32 d0, d0, d0\n"
|
|
"vpadd.u32 d2, d2, d2\n"
|
|
|
|
// Add rhs offset to aggregated rows.
|
|
"vadd.s32 d0, d0, d8\n"
|
|
"vadd.s32 d2, d2, d8\n"
|
|
|
|
// Store reduced rows.
|
|
"vst1.32 {d0[0]}, [%[result]], %[result_stride]\n"
|
|
"vst1.32 {d2[0]}, [%[result]], %[result_stride]\n"
|
|
: [count] "+r"(count), [result_stride] "+r"(result_stride),
|
|
[rhs] "+r"(rhs), [lhs] "+r"(lhs), [result] "+r"(result)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "d10", "d11",
|
|
"cc", "memory");
|
|
}
|
|
|
|
inline void mul_2x8_2x8_int32_rhsadd(const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs,
|
|
std::int32_t count, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
asm volatile(
|
|
// Clear aggregators.
|
|
"vmov.i32 q0, #0\n"
|
|
"vmov.i32 q1, #0\n"
|
|
"vmov.i32 q2, #0\n"
|
|
"vmov.i32 q3, q0\n"
|
|
|
|
"pld [%[lhs]]\n"
|
|
"pld [%[rhs]]\n"
|
|
// General NxM lanes loop.
|
|
"1:"
|
|
|
|
// Subtract counter.
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
"vld1.8 {d8, d9}, [%[lhs]:64]!\n"
|
|
"vld1.8 {d10, d11}, [%[rhs]:64]!\n"
|
|
"pld [%[lhs], #64]\n"
|
|
"pld [%[rhs], #64]\n"
|
|
"vmull.u8 q6, d10, d8\n"
|
|
"vmull.u8 q7, d11, d8\n"
|
|
"vmull.u8 q8, d10, d9\n"
|
|
"vmull.u8 q9, d11, d9\n"
|
|
"vpadal.u16 q0, q6\n"
|
|
"vpadal.u16 q1, q7\n"
|
|
"vpadal.u16 q2, q8\n"
|
|
"vpadal.u16 q3, q9\n"
|
|
|
|
// Loop break.
|
|
"bne 1b\n"
|
|
|
|
"vld1.32 {d8}, [%[rhs]:64]\n"
|
|
|
|
// Horizontal reduce aggregators.
|
|
"vpadd.u32 d0, d0, d1\n"
|
|
"vpadd.u32 d2, d2, d3\n"
|
|
"vpadd.u32 d4, d4, d5\n"
|
|
"vpadd.u32 d6, d6, d7\n"
|
|
|
|
// Reduce rows.
|
|
"vpadd.u32 d0, d0, d2\n"
|
|
"vpadd.u32 d4, d4, d6\n"
|
|
|
|
// Add rhs offset to aggregated rows.
|
|
"vadd.s32 d0, d0, d8\n"
|
|
"vadd.s32 d4, d4, d8\n"
|
|
|
|
// Store reduced rows.
|
|
"vst1.32 {d0}, [%[result]], %[result_stride]\n"
|
|
"vst1.32 {d4}, [%[result]], %[result_stride]\n"
|
|
: [count] "+r"(count), [result_stride] "+r"(result_stride),
|
|
[rhs] "+r"(rhs), [lhs] "+r"(lhs), [result] "+r"(result)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
|
|
"d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "cc",
|
|
"memory");
|
|
}
|
|
|
|
inline void mul_2x8_3x8_int32_rhsadd(const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs,
|
|
std::int32_t count, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
asm volatile(
|
|
// Clear aggregators.
|
|
"vmov.i32 q0, #0\n"
|
|
"vmov.i32 q1, #0\n"
|
|
"vmov.i32 q2, #0\n"
|
|
"vmov.i32 q3, q0\n"
|
|
"vmov.i32 q4, q1\n"
|
|
"vmov.i32 q5, q2\n"
|
|
|
|
"pld [%[lhs]]\n"
|
|
"pld [%[rhs]]\n"
|
|
// General NxM lanes loop.
|
|
"1:"
|
|
|
|
// Subtract counter.
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
"vld1.8 {d12, d13}, [%[lhs]:64]!\n"
|
|
"vld1.8 {d14, d15, d16}, [%[rhs]:64]!\n"
|
|
"pld [%[lhs], #64]\n"
|
|
"pld [%[rhs], #64]\n"
|
|
"vmull.u8 q9, d14, d12\n"
|
|
"vmull.u8 q10, d15, d12\n"
|
|
"vmull.u8 q11, d16, d12\n"
|
|
"vmull.u8 q12, d14, d13\n"
|
|
"vmull.u8 q13, d15, d13\n"
|
|
"vmull.u8 q14, d16, d13\n"
|
|
"vpadal.u16 q0, q9\n"
|
|
"vpadal.u16 q1, q10\n"
|
|
"vpadal.u16 q2, q11\n"
|
|
"vpadal.u16 q3, q12\n"
|
|
"vpadal.u16 q4, q13\n"
|
|
"vpadal.u16 q5, q14\n"
|
|
|
|
// Loop break.
|
|
"bne 1b\n"
|
|
|
|
"vld1.32 {q6}, [%[rhs]:64]\n"
|
|
|
|
// Change stride because storing in two ops.
|
|
"sub %[result_stride], %[result_stride], #8\n"
|
|
|
|
// Horizontal reduce aggregators.
|
|
"vpadd.u32 d0, d0, d1\n"
|
|
"vpadd.u32 d2, d2, d3\n"
|
|
"vpadd.u32 d4, d4, d5\n"
|
|
"vpadd.u32 d6, d6, d7\n"
|
|
"vpadd.u32 d8, d8, d9\n"
|
|
"vpadd.u32 d10, d10, d11\n"
|
|
|
|
// Reduce rows.
|
|
"vpadd.u32 d0, d0, d2\n"
|
|
"vpadd.u32 d1, d4, d4\n"
|
|
"vpadd.u32 d6, d6, d8\n"
|
|
"vpadd.u32 d7, d10, d10\n"
|
|
|
|
// Add rhs offset to aggregated rows.
|
|
"vadd.s32 q0, q0, q6\n"
|
|
"vadd.s32 q3, q3, q6\n"
|
|
|
|
// Store reduced rows.
|
|
"vst1.32 {d0}, [%[result]]!\n"
|
|
"vst1.32 {d1[0]}, [%[result]], %[result_stride]\n"
|
|
|
|
"vst1.32 {d6}, [%[result]]!\n"
|
|
"vst1.32 {d7[0]}, [%[result]], %[result_stride]\n"
|
|
|
|
: [count] "+r"(count), [result_stride] "+r"(result_stride),
|
|
[rhs] "+r"(rhs), [lhs] "+r"(lhs), [result] "+r"(result)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
|
|
"d11", "d12", "d13", "d14", "d15", "d16", "d18", "d19", "d20", "d21",
|
|
"d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "cc", "memory");
|
|
}
|
|
|
|
inline void mul_3x8_1x8_int32_rhsadd(const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs,
|
|
std::int32_t count, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
asm volatile(
|
|
// Clear aggregators.
|
|
"vmov.i32 q0, #0\n"
|
|
"vmov.i32 q1, #0\n"
|
|
"vmov.i32 q2, #0\n"
|
|
|
|
"pld [%[lhs]]\n"
|
|
"pld [%[rhs]]\n"
|
|
// General NxM lanes loop.
|
|
"1:"
|
|
|
|
// Subtract counter.
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
"vld1.8 {d6, d7, d8}, [%[lhs]:64]!\n"
|
|
"vld1.8 {d9}, [%[rhs]:64]!\n"
|
|
"pld [%[lhs], #64]\n"
|
|
"pld [%[rhs], #64]\n"
|
|
"vmull.u8 q5, d9, d6\n"
|
|
"vmull.u8 q6, d9, d7\n"
|
|
"vmull.u8 q7, d9, d8\n"
|
|
"vpadal.u16 q0, q5\n"
|
|
"vpadal.u16 q1, q6\n"
|
|
"vpadal.u16 q2, q7\n"
|
|
|
|
// Loop break.
|
|
"bne 1b\n"
|
|
|
|
"vld1.32 {d8}, [%[rhs]:64]\n"
|
|
|
|
// Horizontal reduce aggregators.
|
|
"vpadd.u32 d0, d0, d1\n"
|
|
"vpadd.u32 d2, d2, d3\n"
|
|
"vpadd.u32 d4, d4, d5\n"
|
|
|
|
// Reduce rows.
|
|
"vpadd.u32 d0, d0, d0\n"
|
|
"vpadd.u32 d2, d2, d2\n"
|
|
"vpadd.u32 d4, d4, d4\n"
|
|
|
|
// Add rhs offset to aggregated rows.
|
|
"vadd.s32 d0, d0, d8\n"
|
|
"vadd.s32 d2, d2, d8\n"
|
|
"vadd.s32 d4, d4, d8\n"
|
|
|
|
// Store reduced rows.
|
|
"vst1.32 {d0[0]}, [%[result]], %[result_stride]\n"
|
|
"vst1.32 {d2[0]}, [%[result]], %[result_stride]\n"
|
|
"vst1.32 {d4[0]}, [%[result]], %[result_stride]\n"
|
|
: [count] "+r"(count), [result_stride] "+r"(result_stride),
|
|
[rhs] "+r"(rhs), [lhs] "+r"(lhs), [result] "+r"(result)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
|
|
"d11", "d12", "d13", "d14", "d15", "cc", "memory");
|
|
}
|
|
|
|
inline void mul_3x8_2x8_int32_rhsadd(const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs,
|
|
std::int32_t count, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
asm volatile(
|
|
// Clear aggregators.
|
|
"vmov.i32 q0, #0\n"
|
|
"vmov.i32 q1, #0\n"
|
|
"vmov.i32 q2, #0\n"
|
|
"vmov.i32 q3, q0\n"
|
|
"vmov.i32 q4, q1\n"
|
|
"vmov.i32 q5, q2\n"
|
|
|
|
"pld [%[lhs]]\n"
|
|
"pld [%[rhs]]\n"
|
|
// General NxM lanes loop.
|
|
"1:"
|
|
|
|
// Subtract counter.
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
"vld1.8 {d12, d13, d14}, [%[lhs]:64]!\n"
|
|
"vld1.8 {d15, d16}, [%[rhs]:64]!\n"
|
|
"pld [%[lhs], #64]\n"
|
|
"pld [%[rhs], #64]\n"
|
|
"vmull.u8 q9, d15, d12\n"
|
|
"vmull.u8 q10, d16, d12\n"
|
|
"vmull.u8 q11, d15, d13\n"
|
|
"vmull.u8 q12, d16, d13\n"
|
|
"vmull.u8 q13, d15, d14\n"
|
|
"vmull.u8 q14, d16, d14\n"
|
|
"vpadal.u16 q0, q9\n"
|
|
"vpadal.u16 q1, q10\n"
|
|
"vpadal.u16 q2, q11\n"
|
|
"vpadal.u16 q3, q12\n"
|
|
"vpadal.u16 q4, q13\n"
|
|
"vpadal.u16 q5, q14\n"
|
|
|
|
// Loop break.
|
|
"bne 1b\n"
|
|
|
|
"vld1.32 {d12}, [%[rhs]:64]\n"
|
|
|
|
// Horizontal reduce aggregators.
|
|
"vpadd.u32 d0, d0, d1\n"
|
|
"vpadd.u32 d2, d2, d3\n"
|
|
"vpadd.u32 d4, d4, d5\n"
|
|
"vpadd.u32 d6, d6, d7\n"
|
|
"vpadd.u32 d8, d8, d9\n"
|
|
"vpadd.u32 d10, d10, d11\n"
|
|
|
|
// Reduce rows.
|
|
"vpadd.u32 d0, d0, d2\n"
|
|
"vpadd.u32 d4, d4, d6\n"
|
|
"vpadd.u32 d8, d8, d10\n"
|
|
|
|
// Add rhs offset to aggregated rows.
|
|
"vadd.s32 d0, d0, d12\n"
|
|
"vadd.s32 d4, d4, d12\n"
|
|
"vadd.s32 d8, d8, d12\n"
|
|
|
|
// Store reduced rows.
|
|
"vst1.32 {d0}, [%[result]], %[result_stride]\n"
|
|
"vst1.32 {d4}, [%[result]], %[result_stride]\n"
|
|
"vst1.32 {d8}, [%[result]], %[result_stride]\n"
|
|
: [count] "+r"(count), [result_stride] "+r"(result_stride),
|
|
[rhs] "+r"(rhs), [lhs] "+r"(lhs), [result] "+r"(result)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
|
|
"d11", "d12", "d13", "d14", "d15", "d16", "d18", "d19", "d20", "d21",
|
|
"d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "cc", "memory");
|
|
}
|
|
|
|
inline void mul_3x8_3x8_int32_rhsadd(const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs,
|
|
std::int32_t count, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
asm volatile(
|
|
// Clear aggregators.
|
|
"vmov.i32 q0, #0\n"
|
|
"vmov.i32 q1, #0\n"
|
|
"vmov.i32 q2, #0\n"
|
|
"vmov.i32 q3, q0\n"
|
|
"vmov.i32 q4, q1\n"
|
|
"vmov.i32 q5, q2\n"
|
|
"vmov.i32 q6, q3\n"
|
|
"vmov.i32 q7, q4\n"
|
|
"vmov.i32 q8, q5\n"
|
|
|
|
"pld [%[lhs]]\n"
|
|
"pld [%[rhs]]\n"
|
|
// 3x3 lanes loop.
|
|
"1:"
|
|
|
|
// Subtract counter.
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
"vld1.8 {d18, d19, d20}, [%[lhs]:64]!\n"
|
|
"vld1.8 {d21, d22, d23}, [%[rhs]:64]!\n"
|
|
"pld [%[lhs], #64]\n"
|
|
"pld [%[rhs], #64]\n"
|
|
"vmull.u8 q12, d18, d21\n"
|
|
"vmull.u8 q13, d18, d22\n"
|
|
"vmull.u8 q14, d18, d23\n"
|
|
"vmull.u8 q15, d19, d21\n"
|
|
"vpadal.u16 q0, q12\n"
|
|
"vpadal.u16 q1, q13\n"
|
|
"vpadal.u16 q2, q14\n"
|
|
"vpadal.u16 q3, q15\n"
|
|
"vmull.u8 q12, d19, d22\n"
|
|
"vmull.u8 q13, d19, d23\n"
|
|
"vmull.u8 q14, d20, d21\n"
|
|
"vmull.u8 q15, d20, d22\n"
|
|
"vmull.u8 q9, d20, d23\n"
|
|
"vpadal.u16 q4, q12\n"
|
|
"vpadal.u16 q5, q13\n"
|
|
"vpadal.u16 q6, q14\n"
|
|
"vpadal.u16 q7, q15\n"
|
|
"vpadal.u16 q8, q9\n"
|
|
|
|
// Loop break.
|
|
"bne 1b\n"
|
|
|
|
"vld1.32 {q9}, [%[rhs]:64]\n"
|
|
|
|
// Change stride because storing in two ops.
|
|
"sub %[result_stride], %[result_stride], #8\n"
|
|
|
|
// Horizontal reduce aggregators.
|
|
"vpadd.u32 d0, d0, d1\n"
|
|
"vpadd.u32 d2, d2, d3\n"
|
|
"vpadd.u32 d4, d4, d5\n"
|
|
"vpadd.u32 d6, d6, d7\n"
|
|
"vpadd.u32 d8, d8, d9\n"
|
|
"vpadd.u32 d10, d10, d11\n"
|
|
"vpadd.u32 d12, d12, d13\n"
|
|
"vpadd.u32 d14, d14, d15\n"
|
|
"vpadd.u32 d16, d16, d17\n"
|
|
|
|
// Reduce rows.
|
|
"vpadd.u32 d0, d0, d2\n"
|
|
"vpadd.u32 d1, d4, d4\n"
|
|
"vpadd.u32 d6, d6, d8\n"
|
|
"vpadd.u32 d7, d10, d10\n"
|
|
"vpadd.u32 d12, d12, d14\n"
|
|
"vpadd.u32 d13, d16, d16\n"
|
|
|
|
// Add rhs offset to aggregated rows.
|
|
"vadd.s32 q0, q0, q9\n"
|
|
"vadd.s32 q3, q3, q9\n"
|
|
"vadd.s32 q6, q6, q9\n"
|
|
|
|
// Store reduced rows.
|
|
"vst1.32 {d0}, [%[result]]!\n"
|
|
"vst1.32 {d1[0]}, [%[result]], %[result_stride]\n"
|
|
|
|
"vst1.32 {d6}, [%[result]]!\n"
|
|
"vst1.32 {d7[0]}, [%[result]], %[result_stride]\n"
|
|
|
|
"vst1.32 {d12}, [%[result]]!\n"
|
|
"vst1.32 {d13[0]}, [%[result]], %[result_stride]\n"
|
|
|
|
: [count] "+r"(count), [result_stride] "+r"(result_stride),
|
|
[rhs] "+r"(rhs), [lhs] "+r"(lhs), [result] "+r"(result)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
|
|
"d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20",
|
|
"d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "d30",
|
|
"d31", "cc", "memory");
|
|
}
|
|
|
|
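// The *_lhsadd_rhsadd variants additionally add the per-row terms that were
// appended after the packed lhs block, on top of the rhs terms, so both
// operands' correction terms are folded into the int32 results here.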
inline void mul_1x8_1x8_int32_lhsadd_rhsadd(const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs,
|
|
std::int32_t count,
|
|
std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
asm volatile(
|
|
// Clear aggregators.
|
|
"vmov.i32 q0, #0\n"
|
|
|
|
"pld [%[lhs]]\n"
|
|
"pld [%[rhs]]\n"
|
|
// General NxM lanes loop.
|
|
"1:"
|
|
|
|
// Subtract counter.
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
"vld1.8 {d2}, [%[lhs]:64]!\n"
|
|
"vld1.8 {d3}, [%[rhs]:64]!\n"
|
|
"pld [%[lhs], #64]\n"
|
|
"pld [%[rhs], #64]\n"
|
|
"vmull.u8 q2, d3, d2\n"
|
|
"vpadal.u16 q0, q2\n"
|
|
|
|
// Loop break.
|
|
"bne 1b\n"
|
|
|
|
"vld1.32 {d8}, [%[lhs]:64]\n"
|
|
"vdup.32 d4, d8[0]\n"
|
|
"vld1.32 {d9}, [%[rhs]:64]\n"
|
|
|
|
// Horizontal reduce aggregators.
|
|
"vpadd.u32 d0, d0, d1\n"
|
|
|
|
// Reduce rows.
|
|
"vpadd.u32 d0, d0, d0\n"
|
|
|
|
// Add lhs offsets to aggregated rows.
|
|
"vadd.s32 d0, d0, d4\n"
|
|
|
|
// Add rhs offset to aggregated rows.
|
|
"vadd.s32 d0, d0, d9\n"
|
|
|
|
// Store reduced rows.
|
|
"vst1.32 {d0[0]}, [%[result]], %[result_stride]\n"
|
|
: [count] "+r"(count), [result_stride] "+r"(result_stride),
|
|
[rhs] "+r"(rhs), [lhs] "+r"(lhs), [result] "+r"(result)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d8", "d9", "cc", "memory");
|
|
}
|
|
|
|
inline void mul_1x8_2x8_int32_lhsadd_rhsadd(const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs,
|
|
std::int32_t count,
|
|
std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
asm volatile(
|
|
// Clear aggregators.
|
|
"vmov.i32 q0, #0\n"
|
|
"vmov.i32 q1, #0\n"
|
|
|
|
"pld [%[lhs]]\n"
|
|
"pld [%[rhs]]\n"
|
|
// General NxM lanes loop.
|
|
"1:"
|
|
|
|
// Subtract counter.
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
"vld1.8 {d4}, [%[lhs]:64]!\n"
|
|
"vld1.8 {d5, d6}, [%[rhs]:64]!\n"
|
|
"pld [%[lhs], #64]\n"
|
|
"pld [%[rhs], #64]\n"
|
|
"vmull.u8 q4, d5, d4\n"
|
|
"vmull.u8 q5, d6, d4\n"
|
|
"vpadal.u16 q0, q4\n"
|
|
"vpadal.u16 q1, q5\n"
|
|
|
|
// Loop break.
|
|
"bne 1b\n"
|
|
|
|
"vld1.32 {d8}, [%[lhs]:64]\n"
|
|
"vdup.32 d4, d8[0]\n"
|
|
"vld1.32 {d9}, [%[rhs]:64]\n"
|
|
|
|
// Horizontal reduce aggregators.
|
|
"vpadd.u32 d0, d0, d1\n"
|
|
"vpadd.u32 d2, d2, d3\n"
|
|
|
|
// Reduce rows.
|
|
"vpadd.u32 d0, d0, d2\n"
|
|
|
|
// Add lhs offsets to aggregated rows.
|
|
"vadd.s32 d0, d0, d4\n"
|
|
|
|
// Add rhs offset to aggregated rows.
|
|
"vadd.s32 d0, d0, d9\n"
|
|
|
|
// Store reduced rows.
|
|
"vst1.32 {d0}, [%[result]], %[result_stride]\n"
|
|
: [count] "+r"(count), [result_stride] "+r"(result_stride),
|
|
[rhs] "+r"(rhs), [lhs] "+r"(lhs), [result] "+r"(result)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "d10", "d11",
|
|
"cc", "memory");
|
|
}
|
|
|
|
inline void mul_1x8_3x8_int32_lhsadd_rhsadd(const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs,
|
|
std::int32_t count,
|
|
std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
asm volatile(
|
|
// Clear aggregators.
|
|
"vmov.i32 q0, #0\n"
|
|
"vmov.i32 q1, #0\n"
|
|
"vmov.i32 q2, #0\n"
|
|
|
|
"pld [%[lhs]]\n"
|
|
"pld [%[rhs]]\n"
|
|
// General NxM lanes loop.
|
|
"1:"
|
|
|
|
// Subtract counter.
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
"vld1.8 {d6}, [%[lhs]:64]!\n"
|
|
"vld1.8 {d7, d8, d9}, [%[rhs]:64]!\n"
|
|
"pld [%[lhs], #64]\n"
|
|
"pld [%[rhs], #64]\n"
|
|
"vmull.u8 q5, d7, d6\n"
|
|
"vmull.u8 q6, d8, d6\n"
|
|
"vmull.u8 q7, d9, d6\n"
|
|
"vpadal.u16 q0, q5\n"
|
|
"vpadal.u16 q1, q6\n"
|
|
"vpadal.u16 q2, q7\n"
|
|
|
|
// Loop break.
|
|
"bne 1b\n"
|
|
|
|
"vld1.32 {d8}, [%[lhs]:64]\n"
|
|
"vdup.32 q5, d8[0]\n"
|
|
"vld1.32 {q6}, [%[rhs]:64]\n"
|
|
|
|
// Change stride because storing in two ops.
|
|
"sub %[result_stride], %[result_stride], #8\n"
|
|
|
|
// Horizontal reduce aggregators.
|
|
"vpadd.u32 d0, d0, d1\n"
|
|
"vpadd.u32 d2, d2, d3\n"
|
|
"vpadd.u32 d4, d4, d5\n"
|
|
|
|
// Reduce rows.
|
|
"vpadd.u32 d0, d0, d2\n"
|
|
"vpadd.u32 d1, d4, d4\n"
|
|
|
|
// Add lhs offsets to aggregated rows.
|
|
"vadd.s32 q0, q0, q5\n"
|
|
|
|
// Add rhs offset to aggregated rows.
|
|
"vadd.s32 q0, q0, q6\n"
|
|
|
|
// Store reduced rows.
|
|
"vst1.32 {d0}, [%[result]]!\n"
|
|
"vst1.32 {d1[0]}, [%[result]], %[result_stride]\n"
|
|
|
|
: [count] "+r"(count), [result_stride] "+r"(result_stride),
|
|
[rhs] "+r"(rhs), [lhs] "+r"(lhs), [result] "+r"(result)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
|
|
"d11", "d12", "d13", "d14", "d15", "cc", "memory");
|
|
}
|
|
|
|
inline void mul_2x8_1x8_int32_lhsadd_rhsadd(const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs,
|
|
std::int32_t count,
|
|
std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
asm volatile(
|
|
// Clear aggregators.
|
|
"vmov.i32 q0, #0\n"
|
|
"vmov.i32 q1, #0\n"
|
|
|
|
"pld [%[lhs]]\n"
|
|
"pld [%[rhs]]\n"
|
|
// General NxM lanes loop.
|
|
"1:"
|
|
|
|
// Subtract counter.
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
"vld1.8 {d4, d5}, [%[lhs]:64]!\n"
|
|
"vld1.8 {d6}, [%[rhs]:64]!\n"
|
|
"pld [%[lhs], #64]\n"
|
|
"pld [%[rhs], #64]\n"
|
|
"vmull.u8 q4, d6, d4\n"
|
|
"vmull.u8 q5, d6, d5\n"
|
|
"vpadal.u16 q0, q4\n"
|
|
"vpadal.u16 q1, q5\n"
|
|
|
|
// Loop break.
|
|
"bne 1b\n"
|
|
|
|
"vld1.32 {d8}, [%[lhs]:64]\n"
|
|
"vdup.32 d4, d8[0]\n"
|
|
"vdup.32 d5, d8[1]\n"
|
|
"vld1.32 {d9}, [%[rhs]:64]\n"
|
|
|
|
// Horizontal reduce aggregators.
|
|
"vpadd.u32 d0, d0, d1\n"
|
|
"vpadd.u32 d2, d2, d3\n"
|
|
|
|
// Reduce rows.
|
|
"vpadd.u32 d0, d0, d0\n"
|
|
"vpadd.u32 d2, d2, d2\n"
|
|
|
|
// Add lhs offsets to aggregated rows.
|
|
"vadd.s32 d0, d0, d4\n"
|
|
"vadd.s32 d2, d2, d5\n"
|
|
|
|
// Add rhs offset to aggregated rows.
|
|
"vadd.s32 d0, d0, d9\n"
|
|
"vadd.s32 d2, d2, d9\n"
|
|
|
|
// Store reduced rows.
|
|
"vst1.32 {d0[0]}, [%[result]], %[result_stride]\n"
|
|
"vst1.32 {d2[0]}, [%[result]], %[result_stride]\n"
|
|
: [count] "+r"(count), [result_stride] "+r"(result_stride),
|
|
[rhs] "+r"(rhs), [lhs] "+r"(lhs), [result] "+r"(result)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "d10", "d11",
|
|
"cc", "memory");
|
|
}
|
|
|
|
inline void mul_2x8_2x8_int32_lhsadd_rhsadd(const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs,
|
|
std::int32_t count,
|
|
std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
asm volatile(
|
|
// Clear aggregators.
|
|
"vmov.i32 q0, #0\n"
|
|
"vmov.i32 q1, #0\n"
|
|
"vmov.i32 q2, #0\n"
|
|
"vmov.i32 q3, q0\n"
|
|
|
|
"pld [%[lhs]]\n"
|
|
"pld [%[rhs]]\n"
|
|
// General NxM lanes loop.
|
|
"1:"
|
|
|
|
// Subtract counter.
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
"vld1.8 {d8, d9}, [%[lhs]:64]!\n"
|
|
"vld1.8 {d10, d11}, [%[rhs]:64]!\n"
|
|
"pld [%[lhs], #64]\n"
|
|
"pld [%[rhs], #64]\n"
|
|
"vmull.u8 q6, d10, d8\n"
|
|
"vmull.u8 q7, d11, d8\n"
|
|
"vmull.u8 q8, d10, d9\n"
|
|
"vmull.u8 q9, d11, d9\n"
|
|
"vpadal.u16 q0, q6\n"
|
|
"vpadal.u16 q1, q7\n"
|
|
"vpadal.u16 q2, q8\n"
|
|
"vpadal.u16 q3, q9\n"
|
|
|
|
// Loop break.
|
|
"bne 1b\n"
|
|
|
|
"vld1.32 {d8}, [%[lhs]:64]\n"
|
|
"vdup.32 d9, d8[0]\n"
|
|
"vdup.32 d10, d8[1]\n"
|
|
"vld1.32 {d11}, [%[rhs]:64]\n"
|
|
|
|
// Horizontal reduce aggregators.
|
|
"vpadd.u32 d0, d0, d1\n"
|
|
"vpadd.u32 d2, d2, d3\n"
|
|
"vpadd.u32 d4, d4, d5\n"
|
|
"vpadd.u32 d6, d6, d7\n"
|
|
|
|
// Reduce rows.
|
|
"vpadd.u32 d0, d0, d2\n"
|
|
"vpadd.u32 d4, d4, d6\n"
|
|
|
|
// Add lhs offsets to aggregated rows.
|
|
"vadd.s32 d0, d0, d9\n"
|
|
"vadd.s32 d4, d4, d10\n"
|
|
|
|
// Add rhs offset to aggregated rows.
|
|
"vadd.s32 d0, d0, d11\n"
|
|
"vadd.s32 d4, d4, d11\n"
|
|
|
|
// Store reduced rows.
|
|
"vst1.32 {d0}, [%[result]], %[result_stride]\n"
|
|
"vst1.32 {d4}, [%[result]], %[result_stride]\n"
|
|
: [count] "+r"(count), [result_stride] "+r"(result_stride),
|
|
[rhs] "+r"(rhs), [lhs] "+r"(lhs), [result] "+r"(result)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
|
|
"d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "cc",
|
|
"memory");
|
|
}
|
|
|
|
inline void mul_2x8_3x8_int32_lhsadd_rhsadd(const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs,
|
|
std::int32_t count,
|
|
std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
asm volatile(
|
|
// Clear aggregators.
|
|
"vmov.i32 q0, #0\n"
|
|
"vmov.i32 q1, #0\n"
|
|
"vmov.i32 q2, #0\n"
|
|
"vmov.i32 q3, q0\n"
|
|
"vmov.i32 q4, q1\n"
|
|
"vmov.i32 q5, q2\n"
|
|
|
|
"pld [%[lhs]]\n"
|
|
"pld [%[rhs]]\n"
|
|
// General NxM lanes loop.
|
|
"1:"
|
|
|
|
// Subtract counter.
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
"vld1.8 {d12, d13}, [%[lhs]:64]!\n"
|
|
"vld1.8 {d14, d15, d16}, [%[rhs]:64]!\n"
|
|
"pld [%[lhs], #64]\n"
|
|
"pld [%[rhs], #64]\n"
|
|
"vmull.u8 q9, d14, d12\n"
|
|
"vmull.u8 q10, d15, d12\n"
|
|
"vmull.u8 q11, d16, d12\n"
|
|
"vmull.u8 q12, d14, d13\n"
|
|
"vmull.u8 q13, d15, d13\n"
|
|
"vmull.u8 q14, d16, d13\n"
|
|
"vpadal.u16 q0, q9\n"
|
|
"vpadal.u16 q1, q10\n"
|
|
"vpadal.u16 q2, q11\n"
|
|
"vpadal.u16 q3, q12\n"
|
|
"vpadal.u16 q4, q13\n"
|
|
"vpadal.u16 q5, q14\n"
|
|
|
|
// Loop break.
|
|
"bne 1b\n"
|
|
|
|
"vld1.32 {d12}, [%[lhs]:64]\n"
|
|
"vdup.32 q7, d12[0]\n"
|
|
"vdup.32 q8, d12[1]\n"
|
|
"vld1.32 {q9}, [%[rhs]:64]\n"
|
|
|
|
// Change stride because storing in two ops.
|
|
"sub %[result_stride], %[result_stride], #8\n"
|
|
|
|
// Horizontal reduce aggregators.
|
|
"vpadd.u32 d0, d0, d1\n"
|
|
"vpadd.u32 d2, d2, d3\n"
|
|
"vpadd.u32 d4, d4, d5\n"
|
|
"vpadd.u32 d6, d6, d7\n"
|
|
"vpadd.u32 d8, d8, d9\n"
|
|
"vpadd.u32 d10, d10, d11\n"
|
|
|
|
// Reduce rows.
|
|
"vpadd.u32 d0, d0, d2\n"
|
|
"vpadd.u32 d1, d4, d4\n"
|
|
"vpadd.u32 d6, d6, d8\n"
|
|
"vpadd.u32 d7, d10, d10\n"
|
|
|
|
// Add lhs offsets to aggregated rows.
|
|
"vadd.s32 q0, q0, q7\n"
|
|
"vadd.s32 q3, q3, q8\n"
|
|
|
|
// Add rhs offset to aggregated rows.
|
|
"vadd.s32 q0, q0, q9\n"
|
|
"vadd.s32 q3, q3, q9\n"
|
|
|
|
// Store reduced rows.
|
|
"vst1.32 {d0}, [%[result]]!\n"
|
|
"vst1.32 {d1[0]}, [%[result]], %[result_stride]\n"
|
|
|
|
"vst1.32 {d6}, [%[result]]!\n"
|
|
"vst1.32 {d7[0]}, [%[result]], %[result_stride]\n"
|
|
|
|
: [count] "+r"(count), [result_stride] "+r"(result_stride),
|
|
[rhs] "+r"(rhs), [lhs] "+r"(lhs), [result] "+r"(result)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
|
|
"d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20",
|
|
"d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "cc",
|
|
"memory");
|
|
}
|
|
|
|
inline void mul_3x8_1x8_int32_lhsadd_rhsadd(const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs,
|
|
std::int32_t count,
|
|
std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
asm volatile(
|
|
// Clear aggregators.
|
|
"vmov.i32 q0, #0\n"
|
|
"vmov.i32 q1, #0\n"
|
|
"vmov.i32 q2, #0\n"
|
|
|
|
"pld [%[lhs]]\n"
|
|
"pld [%[rhs]]\n"
|
|
// General NxM lanes loop.
|
|
"1:"
|
|
|
|
// Subtract counter.
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
"vld1.8 {d6, d7, d8}, [%[lhs]:64]!\n"
|
|
"vld1.8 {d9}, [%[rhs]:64]!\n"
|
|
"pld [%[lhs], #64]\n"
|
|
"pld [%[rhs], #64]\n"
|
|
"vmull.u8 q5, d9, d6\n"
|
|
"vmull.u8 q6, d9, d7\n"
|
|
"vmull.u8 q7, d9, d8\n"
|
|
"vpadal.u16 q0, q5\n"
|
|
"vpadal.u16 q1, q6\n"
|
|
"vpadal.u16 q2, q7\n"
|
|
|
|
// Loop break.
|
|
"bne 1b\n"
|
|
|
|
"vld1.32 {q4}, [%[lhs]:64]\n"
|
|
"vdup.32 d6, d8[0]\n"
|
|
"vdup.32 d7, d8[1]\n"
|
|
"vdup.32 d10, d9[0]\n"
|
|
"vld1.32 {d11}, [%[rhs]:64]\n"
|
|
|
|
// Horizontal reduce aggregators.
|
|
"vpadd.u32 d0, d0, d1\n"
|
|
"vpadd.u32 d2, d2, d3\n"
|
|
"vpadd.u32 d4, d4, d5\n"
|
|
|
|
// Reduce rows.
|
|
"vpadd.u32 d0, d0, d0\n"
|
|
"vpadd.u32 d2, d2, d2\n"
|
|
"vpadd.u32 d4, d4, d4\n"
|
|
|
|
// Add lhs offsets to aggregated rows.
|
|
"vadd.s32 d0, d0, d6\n"
|
|
"vadd.s32 d2, d2, d7\n"
|
|
"vadd.s32 d4, d4, d10\n"
|
|
|
|
// Add rhs offset to aggregated rows.
|
|
"vadd.s32 d0, d0, d11\n"
|
|
"vadd.s32 d2, d2, d11\n"
|
|
"vadd.s32 d4, d4, d11\n"
|
|
|
|
// Store reduced rows.
|
|
"vst1.32 {d0[0]}, [%[result]], %[result_stride]\n"
|
|
"vst1.32 {d2[0]}, [%[result]], %[result_stride]\n"
|
|
"vst1.32 {d4[0]}, [%[result]], %[result_stride]\n"
|
|
: [count] "+r"(count), [result_stride] "+r"(result_stride),
|
|
[rhs] "+r"(rhs), [lhs] "+r"(lhs), [result] "+r"(result)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
|
|
"d11", "d12", "d13", "d14", "d15", "cc", "memory");
|
|
}
|
|
|
|
inline void mul_3x8_2x8_int32_lhsadd_rhsadd(const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs,
|
|
std::int32_t count,
|
|
std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
asm volatile(
|
|
// Clear aggregators.
|
|
"vmov.i32 q0, #0\n"
|
|
"vmov.i32 q1, #0\n"
|
|
"vmov.i32 q2, #0\n"
|
|
"vmov.i32 q3, q0\n"
|
|
"vmov.i32 q4, q1\n"
|
|
"vmov.i32 q5, q2\n"
|
|
|
|
"pld [%[lhs]]\n"
|
|
"pld [%[rhs]]\n"
|
|
// General NxM lanes loop.
|
|
"1:"
|
|
|
|
// Subtract counter.
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
"vld1.8 {d12, d13, d14}, [%[lhs]:64]!\n"
|
|
"vld1.8 {d15, d16}, [%[rhs]:64]!\n"
|
|
"pld [%[lhs], #64]\n"
|
|
"pld [%[rhs], #64]\n"
|
|
"vmull.u8 q9, d15, d12\n"
|
|
"vmull.u8 q10, d16, d12\n"
|
|
"vmull.u8 q11, d15, d13\n"
|
|
"vmull.u8 q12, d16, d13\n"
|
|
"vmull.u8 q13, d15, d14\n"
|
|
"vmull.u8 q14, d16, d14\n"
|
|
"vpadal.u16 q0, q9\n"
|
|
"vpadal.u16 q1, q10\n"
|
|
"vpadal.u16 q2, q11\n"
|
|
"vpadal.u16 q3, q12\n"
|
|
"vpadal.u16 q4, q13\n"
|
|
"vpadal.u16 q5, q14\n"
|
|
|
|
// Loop break.
|
|
"bne 1b\n"
|
|
|
|
"vld1.32 {q6}, [%[lhs]:64]\n"
|
|
"vdup.32 d14, d12[0]\n"
|
|
"vdup.32 d15, d12[1]\n"
|
|
"vdup.32 d16, d13[0]\n"
|
|
"vld1.32 {d17}, [%[rhs]:64]\n"
|
|
|
|
// Horizontal reduce aggregators.
|
|
"vpadd.u32 d0, d0, d1\n"
|
|
"vpadd.u32 d2, d2, d3\n"
|
|
"vpadd.u32 d4, d4, d5\n"
|
|
"vpadd.u32 d6, d6, d7\n"
|
|
"vpadd.u32 d8, d8, d9\n"
|
|
"vpadd.u32 d10, d10, d11\n"
|
|
|
|
// Reduce rows.
|
|
"vpadd.u32 d0, d0, d2\n"
|
|
"vpadd.u32 d4, d4, d6\n"
|
|
"vpadd.u32 d8, d8, d10\n"
|
|
|
|
// Add lhs offsets to aggregated rows.
|
|
"vadd.s32 d0, d0, d14\n"
|
|
"vadd.s32 d4, d4, d15\n"
|
|
"vadd.s32 d8, d8, d16\n"
|
|
|
|
// Add rhs offset to aggregated rows.
|
|
"vadd.s32 d0, d0, d17\n"
|
|
"vadd.s32 d4, d4, d17\n"
|
|
"vadd.s32 d8, d8, d17\n"
|
|
|
|
// Store reduced rows.
|
|
"vst1.32 {d0}, [%[result]], %[result_stride]\n"
|
|
"vst1.32 {d4}, [%[result]], %[result_stride]\n"
|
|
"vst1.32 {d8}, [%[result]], %[result_stride]\n"
|
|
: [count] "+r"(count), [result_stride] "+r"(result_stride),
|
|
[rhs] "+r"(rhs), [lhs] "+r"(lhs), [result] "+r"(result)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
|
|
"d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20",
|
|
"d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "cc",
|
|
"memory");
|
|
}
|
|
|
|
inline void mul_3x8_3x8_int32_lhsadd_rhsadd(const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs,
|
|
std::int32_t count,
|
|
std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
asm volatile(
|
|
// Clear aggregators.
|
|
"vmov.i32 q0, #0\n"
|
|
"vmov.i32 q1, #0\n"
|
|
"vmov.i32 q2, #0\n"
|
|
"vmov.i32 q3, q0\n"
|
|
"vmov.i32 q4, q1\n"
|
|
"vmov.i32 q5, q2\n"
|
|
"vmov.i32 q6, q3\n"
|
|
"vmov.i32 q7, q4\n"
|
|
"vmov.i32 q8, q5\n"
|
|
|
|
"pld [%[lhs]]\n"
|
|
"pld [%[rhs]]\n"
|
|
// 3x3 lanes loop.
|
|
"1:"
|
|
|
|
// Subtract counter.
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
"vld1.8 {d18, d19, d20}, [%[lhs]:64]!\n"
|
|
"vld1.8 {d21, d22, d23}, [%[rhs]:64]!\n"
|
|
"pld [%[lhs], #64]\n"
|
|
"pld [%[rhs], #64]\n"
|
|
"vmull.u8 q12, d18, d21\n"
|
|
"vmull.u8 q13, d18, d22\n"
|
|
"vmull.u8 q14, d18, d23\n"
|
|
"vmull.u8 q15, d19, d21\n"
|
|
"vpadal.u16 q0, q12\n"
|
|
"vpadal.u16 q1, q13\n"
|
|
"vpadal.u16 q2, q14\n"
|
|
"vpadal.u16 q3, q15\n"
|
|
"vmull.u8 q12, d19, d22\n"
|
|
"vmull.u8 q13, d19, d23\n"
|
|
"vmull.u8 q14, d20, d21\n"
|
|
"vmull.u8 q15, d20, d22\n"
|
|
"vmull.u8 q9, d20, d23\n"
|
|
"vpadal.u16 q4, q12\n"
|
|
"vpadal.u16 q5, q13\n"
|
|
"vpadal.u16 q6, q14\n"
|
|
"vpadal.u16 q7, q15\n"
|
|
"vpadal.u16 q8, q9\n"
|
|
|
|
// Loop break.
|
|
"bne 1b\n"
|
|
|
|
"vld1.32 {q9}, [%[lhs]:64]\n"
|
|
"vdup.32 q10, d18[0]\n"
|
|
"vdup.32 q11, d18[1]\n"
|
|
"vdup.32 q12, d19[0]\n"
|
|
"vld1.32 {q13}, [%[rhs]:64]\n"
|
|
|
|
// Change stride because storing in two ops.
|
|
"sub %[result_stride], %[result_stride], #8\n"
|
|
|
|
// Horizontal reduce aggregators.
|
|
"vpadd.u32 d0, d0, d1\n"
|
|
"vpadd.u32 d2, d2, d3\n"
|
|
"vpadd.u32 d4, d4, d5\n"
|
|
"vpadd.u32 d6, d6, d7\n"
|
|
"vpadd.u32 d8, d8, d9\n"
|
|
"vpadd.u32 d10, d10, d11\n"
|
|
"vpadd.u32 d12, d12, d13\n"
|
|
"vpadd.u32 d14, d14, d15\n"
|
|
"vpadd.u32 d16, d16, d17\n"
|
|
|
|
// Reduce rows.
|
|
"vpadd.u32 d0, d0, d2\n"
|
|
"vpadd.u32 d1, d4, d4\n"
|
|
"vpadd.u32 d6, d6, d8\n"
|
|
"vpadd.u32 d7, d10, d10\n"
|
|
"vpadd.u32 d12, d12, d14\n"
|
|
"vpadd.u32 d13, d16, d16\n"
|
|
|
|
// Add lhs offsets to aggregated rows.
|
|
"vadd.s32 q0, q0, q10\n"
|
|
"vadd.s32 q3, q3, q11\n"
|
|
"vadd.s32 q6, q6, q12\n"
|
|
|
|
// Add rhs offset to aggregated rows.
|
|
"vadd.s32 q0, q0, q13\n"
|
|
"vadd.s32 q3, q3, q13\n"
|
|
"vadd.s32 q6, q6, q13\n"
|
|
|
|
// Store reduced rows.
|
|
"vst1.32 {d0}, [%[result]]!\n"
|
|
"vst1.32 {d1[0]}, [%[result]], %[result_stride]\n"
|
|
|
|
"vst1.32 {d6}, [%[result]]!\n"
|
|
"vst1.32 {d7[0]}, [%[result]], %[result_stride]\n"
|
|
|
|
"vst1.32 {d12}, [%[result]]!\n"
|
|
"vst1.32 {d13[0]}, [%[result]], %[result_stride]\n"
|
|
|
|
: [count] "+r"(count), [result_stride] "+r"(result_stride),
|
|
[rhs] "+r"(rhs), [lhs] "+r"(lhs), [result] "+r"(result)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
|
|
"d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20",
|
|
"d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "d30",
|
|
"d31", "cc", "memory");
|
|
}
|
|
|
|
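// The *_float_lhsadd_rhsadd variants perform the same integer reduction and
// offset additions, then convert the results to float and multiply by
// result_scale before storing.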
inline void mul_1x8_1x8_float_lhsadd_rhsadd(const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs,
|
|
std::int32_t count, float* result,
|
|
std::int32_t result_stride,
|
|
float result_scale) {
|
|
asm volatile(
|
|
// Clear aggregators.
|
|
"vmov.i32 q0, #0\n"
|
|
|
|
"pld [%[lhs]]\n"
|
|
"pld [%[rhs]]\n"
|
|
// General NxM lanes loop.
|
|
"1:"
|
|
|
|
// Subtract counter.
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
"vld1.8 {d2}, [%[lhs]:64]!\n"
|
|
"vld1.8 {d3}, [%[rhs]:64]!\n"
|
|
"pld [%[lhs], #64]\n"
|
|
"pld [%[rhs], #64]\n"
|
|
"vmull.u8 q2, d3, d2\n"
|
|
"vpadal.u16 q0, q2\n"
|
|
|
|
// Loop break.
|
|
"bne 1b\n"
|
|
|
|
"vld1.32 {d8}, [%[lhs]:64]\n"
|
|
"vdup.32 d4, d8[0]\n"
|
|
"vld1.32 {d9}, [%[rhs]:64]\n"
|
|
"vdup.32 d5, %[result_scale]\n"
|
|
|
|
// Horizontal reduce aggregators.
|
|
"vpadd.u32 d0, d0, d1\n"
|
|
|
|
// Reduce rows.
|
|
"vpadd.u32 d0, d0, d0\n"
|
|
|
|
// Add lhs offsets to aggregated rows.
|
|
"vadd.s32 d0, d0, d4\n"
|
|
|
|
// Add rhs offset to aggregated rows.
|
|
"vadd.s32 d0, d0, d9\n"
|
|
|
|
// Convert to float. Multiply by result scale.
|
|
"vcvt.f32.s32 d0, d0\n"
|
|
"vmul.f32 d0, d0, d5\n"
|
|
|
|
// Store reduced rows.
|
|
"vst1.32 {d0[0]}, [%[result]], %[result_stride]\n"
|
|
: [count] "+r"(count), [result_scale] "+r"(result_scale),
|
|
[result_stride] "+r"(result_stride), [rhs] "+r"(rhs), [lhs] "+r"(lhs),
|
|
[result] "+r"(result)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d8", "d9", "cc", "memory");
|
|
}
|
|
|
|
inline void mul_1x8_2x8_float_lhsadd_rhsadd(const std::uint8_t* lhs,
                                            const std::uint8_t* rhs,
                                            std::int32_t count, float* result,
                                            std::int32_t result_stride,
                                            float result_scale) {
  asm volatile(
      // Clear aggregators.
      "vmov.i32 q0, #0\n"
      "vmov.i32 q1, #0\n"

      "pld [%[lhs]]\n"
      "pld [%[rhs]]\n"
      // General NxM lanes loop.
      "1:"

      // Subtract counter.
      "subs %[count], %[count], #8\n"

      "vld1.8 {d4}, [%[lhs]:64]!\n"
      "vld1.8 {d5, d6}, [%[rhs]:64]!\n"
      "pld [%[lhs], #64]\n"
      "pld [%[rhs], #64]\n"
      "vmull.u8 q4, d5, d4\n"
      "vmull.u8 q5, d6, d4\n"
      "vpadal.u16 q0, q4\n"
      "vpadal.u16 q1, q5\n"

      // Loop break.
      "bne 1b\n"

      "vld1.32 {d8}, [%[lhs]:64]\n"
      "vdup.32 d4, d8[0]\n"
      "vld1.32 {d9}, [%[rhs]:64]\n"
      "vdup.32 d5, %[result_scale]\n"

      // Horizontal reduce aggregators.
      "vpadd.u32 d0, d0, d1\n"
      "vpadd.u32 d2, d2, d3\n"

      // Reduce rows.
      "vpadd.u32 d0, d0, d2\n"

      // Add lhs offsets to aggregated rows.
      "vadd.s32 d0, d0, d4\n"

      // Add rhs offset to aggregated rows.
      "vadd.s32 d0, d0, d9\n"

      // Convert to float. Multiply by result scale.
      "vcvt.f32.s32 d0, d0\n"
      "vmul.f32 d0, d0, d5\n"

      // Store reduced rows.
      "vst1.32 {d0}, [%[result]], %[result_stride]\n"
      : [count] "+r"(count), [result_scale] "+r"(result_scale),
        [result_stride] "+r"(result_stride), [rhs] "+r"(rhs), [lhs] "+r"(lhs),
        [result] "+r"(result)
      :
      : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "d10", "d11",
        "cc", "memory");
}

inline void mul_1x8_3x8_float_lhsadd_rhsadd(const std::uint8_t* lhs,
                                            const std::uint8_t* rhs,
                                            std::int32_t count, float* result,
                                            std::int32_t result_stride,
                                            float result_scale) {
  asm volatile(
      // Clear aggregators.
      "vmov.i32 q0, #0\n"
      "vmov.i32 q1, #0\n"
      "vmov.i32 q2, #0\n"

      "pld [%[lhs]]\n"
      "pld [%[rhs]]\n"
      // General NxM lanes loop.
      "1:"

      // Subtract counter.
      "subs %[count], %[count], #8\n"

      "vld1.8 {d6}, [%[lhs]:64]!\n"
      "vld1.8 {d7, d8, d9}, [%[rhs]:64]!\n"
      "pld [%[lhs], #64]\n"
      "pld [%[rhs], #64]\n"
      "vmull.u8 q5, d7, d6\n"
      "vmull.u8 q6, d8, d6\n"
      "vmull.u8 q7, d9, d6\n"
      "vpadal.u16 q0, q5\n"
      "vpadal.u16 q1, q6\n"
      "vpadal.u16 q2, q7\n"

      // Loop break.
      "bne 1b\n"

      "vld1.32 {d8}, [%[lhs]:64]\n"
      "vdup.32 q5, d8[0]\n"
      "vld1.32 {q6}, [%[rhs]:64]\n"
      "vdup.32 q7, %[result_scale]\n"

      // Change stride because storing in two ops.
      "sub %[result_stride], %[result_stride], #8\n"

      // Horizontal reduce aggregators.
      "vpadd.u32 d0, d0, d1\n"
      "vpadd.u32 d2, d2, d3\n"
      "vpadd.u32 d4, d4, d5\n"

      // Reduce rows.
      "vpadd.u32 d0, d0, d2\n"
      "vpadd.u32 d1, d4, d4\n"

      // Add lhs offsets to aggregated rows.
      "vadd.s32 q0, q0, q5\n"

      // Add rhs offset to aggregated rows.
      "vadd.s32 q0, q0, q6\n"

      // Convert to float. Multiply by result scale.
      "vcvt.f32.s32 q0, q0\n"
      "vmul.f32 q0, q0, q7\n"

      // Store reduced rows.
      "vst1.32 {d0}, [%[result]]!\n"
      "vst1.32 {d1[0]}, [%[result]], %[result_stride]\n"

      : [count] "+r"(count), [result_scale] "+r"(result_scale),
        [result_stride] "+r"(result_stride), [rhs] "+r"(rhs), [lhs] "+r"(lhs),
        [result] "+r"(result)
      :
      : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
        "d11", "d12", "d13", "d14", "d15", "cc", "memory");
}

inline void mul_2x8_1x8_float_lhsadd_rhsadd(const std::uint8_t* lhs,
                                            const std::uint8_t* rhs,
                                            std::int32_t count, float* result,
                                            std::int32_t result_stride,
                                            float result_scale) {
  asm volatile(
      // Clear aggregators.
      "vmov.i32 q0, #0\n"
      "vmov.i32 q1, #0\n"

      "pld [%[lhs]]\n"
      "pld [%[rhs]]\n"
      // General NxM lanes loop.
      "1:"

      // Subtract counter.
      "subs %[count], %[count], #8\n"

      "vld1.8 {d4, d5}, [%[lhs]:64]!\n"
      "vld1.8 {d6}, [%[rhs]:64]!\n"
      "pld [%[lhs], #64]\n"
      "pld [%[rhs], #64]\n"
      "vmull.u8 q4, d6, d4\n"
      "vmull.u8 q5, d6, d5\n"
      "vpadal.u16 q0, q4\n"
      "vpadal.u16 q1, q5\n"

      // Loop break.
      "bne 1b\n"

      "vld1.32 {d8}, [%[lhs]:64]\n"
      "vdup.32 d4, d8[0]\n"
      "vdup.32 d5, d8[1]\n"
      "vld1.32 {d9}, [%[rhs]:64]\n"
      "vdup.32 d6, %[result_scale]\n"

      // Horizontal reduce aggregators.
      "vpadd.u32 d0, d0, d1\n"
      "vpadd.u32 d2, d2, d3\n"

      // Reduce rows.
      "vpadd.u32 d0, d0, d0\n"
      "vpadd.u32 d2, d2, d2\n"

      // Add lhs offsets to aggregated rows.
      "vadd.s32 d0, d0, d4\n"
      "vadd.s32 d2, d2, d5\n"

      // Add rhs offset to aggregated rows.
      "vadd.s32 d0, d0, d9\n"
      "vadd.s32 d2, d2, d9\n"

      // Convert to float. Multiply by result scale.
      "vcvt.f32.s32 d0, d0\n"
      "vcvt.f32.s32 d2, d2\n"
      "vmul.f32 d0, d0, d6\n"
      "vmul.f32 d2, d2, d6\n"

      // Store reduced rows.
      "vst1.32 {d0[0]}, [%[result]], %[result_stride]\n"
      "vst1.32 {d2[0]}, [%[result]], %[result_stride]\n"
      : [count] "+r"(count), [result_scale] "+r"(result_scale),
        [result_stride] "+r"(result_stride), [rhs] "+r"(rhs), [lhs] "+r"(lhs),
        [result] "+r"(result)
      :
      : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d8", "d9", "d10", "d11",
        "cc", "memory");
}

inline void mul_2x8_2x8_float_lhsadd_rhsadd(const std::uint8_t* lhs,
                                            const std::uint8_t* rhs,
                                            std::int32_t count, float* result,
                                            std::int32_t result_stride,
                                            float result_scale) {
  asm volatile(
      // Clear aggregators.
      "vmov.i32 q0, #0\n"
      "vmov.i32 q1, #0\n"
      "vmov.i32 q2, #0\n"
      "vmov.i32 q3, q0\n"

      "pld [%[lhs]]\n"
      "pld [%[rhs]]\n"
      // General NxM lanes loop.
      "1:"

      // Subtract counter.
      "subs %[count], %[count], #8\n"

      "vld1.8 {d8, d9}, [%[lhs]:64]!\n"
      "vld1.8 {d10, d11}, [%[rhs]:64]!\n"
      "pld [%[lhs], #64]\n"
      "pld [%[rhs], #64]\n"
      "vmull.u8 q6, d10, d8\n"
      "vmull.u8 q7, d11, d8\n"
      "vmull.u8 q8, d10, d9\n"
      "vmull.u8 q9, d11, d9\n"
      "vpadal.u16 q0, q6\n"
      "vpadal.u16 q1, q7\n"
      "vpadal.u16 q2, q8\n"
      "vpadal.u16 q3, q9\n"

      // Loop break.
      "bne 1b\n"

      "vld1.32 {d8}, [%[lhs]:64]\n"
      "vdup.32 d9, d8[0]\n"
      "vdup.32 d10, d8[1]\n"
      "vld1.32 {d11}, [%[rhs]:64]\n"
      "vdup.32 d12, %[result_scale]\n"

      // Horizontal reduce aggregators.
      "vpadd.u32 d0, d0, d1\n"
      "vpadd.u32 d2, d2, d3\n"
      "vpadd.u32 d4, d4, d5\n"
      "vpadd.u32 d6, d6, d7\n"

      // Reduce rows.
      "vpadd.u32 d0, d0, d2\n"
      "vpadd.u32 d4, d4, d6\n"

      // Add lhs offsets to aggregated rows.
      "vadd.s32 d0, d0, d9\n"
      "vadd.s32 d4, d4, d10\n"

      // Add rhs offset to aggregated rows.
      "vadd.s32 d0, d0, d11\n"
      "vadd.s32 d4, d4, d11\n"

      // Convert to float. Multiply by result scale.
      "vcvt.f32.s32 d0, d0\n"
      "vcvt.f32.s32 d4, d4\n"
      "vmul.f32 d0, d0, d12\n"
      "vmul.f32 d4, d4, d12\n"

      // Store reduced rows.
      "vst1.32 {d0}, [%[result]], %[result_stride]\n"
      "vst1.32 {d4}, [%[result]], %[result_stride]\n"
      : [count] "+r"(count), [result_scale] "+r"(result_scale),
        [result_stride] "+r"(result_stride), [rhs] "+r"(rhs), [lhs] "+r"(lhs),
        [result] "+r"(result)
      :
      : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
        "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "cc",
        "memory");
}

inline void mul_2x8_3x8_float_lhsadd_rhsadd(const std::uint8_t* lhs,
                                            const std::uint8_t* rhs,
                                            std::int32_t count, float* result,
                                            std::int32_t result_stride,
                                            float result_scale) {
  asm volatile(
      // Clear aggregators.
      "vmov.i32 q0, #0\n"
      "vmov.i32 q1, #0\n"
      "vmov.i32 q2, #0\n"
      "vmov.i32 q3, q0\n"
      "vmov.i32 q4, q1\n"
      "vmov.i32 q5, q2\n"

      "pld [%[lhs]]\n"
      "pld [%[rhs]]\n"
      // General NxM lanes loop.
      "1:"

      // Subtract counter.
      "subs %[count], %[count], #8\n"

      "vld1.8 {d12, d13}, [%[lhs]:64]!\n"
      "vld1.8 {d14, d15, d16}, [%[rhs]:64]!\n"
      "pld [%[lhs], #64]\n"
      "pld [%[rhs], #64]\n"
      "vmull.u8 q9, d14, d12\n"
      "vmull.u8 q10, d15, d12\n"
      "vmull.u8 q11, d16, d12\n"
      "vmull.u8 q12, d14, d13\n"
      "vmull.u8 q13, d15, d13\n"
      "vmull.u8 q14, d16, d13\n"
      "vpadal.u16 q0, q9\n"
      "vpadal.u16 q1, q10\n"
      "vpadal.u16 q2, q11\n"
      "vpadal.u16 q3, q12\n"
      "vpadal.u16 q4, q13\n"
      "vpadal.u16 q5, q14\n"

      // Loop break.
      "bne 1b\n"

      "vld1.32 {d12}, [%[lhs]:64]\n"
      "vdup.32 q7, d12[0]\n"
      "vdup.32 q8, d12[1]\n"
      "vld1.32 {q9}, [%[rhs]:64]\n"
      "vdup.32 q10, %[result_scale]\n"

      // Change stride because storing in two ops.
      "sub %[result_stride], %[result_stride], #8\n"

      // Horizontal reduce aggregators.
      "vpadd.u32 d0, d0, d1\n"
      "vpadd.u32 d2, d2, d3\n"
      "vpadd.u32 d4, d4, d5\n"
      "vpadd.u32 d6, d6, d7\n"
      "vpadd.u32 d8, d8, d9\n"
      "vpadd.u32 d10, d10, d11\n"

      // Reduce rows.
      "vpadd.u32 d0, d0, d2\n"
      "vpadd.u32 d1, d4, d4\n"
      "vpadd.u32 d6, d6, d8\n"
      "vpadd.u32 d7, d10, d10\n"

      // Add lhs offsets to aggregated rows.
      "vadd.s32 q0, q0, q7\n"
      "vadd.s32 q3, q3, q8\n"

      // Add rhs offset to aggregated rows.
      "vadd.s32 q0, q0, q9\n"
      "vadd.s32 q3, q3, q9\n"

      // Convert to float. Multiply by result scale.
      "vcvt.f32.s32 q0, q0\n"
      "vcvt.f32.s32 q3, q3\n"
      "vmul.f32 q0, q0, q10\n"
      "vmul.f32 q3, q3, q10\n"

      // Store reduced rows.
      "vst1.32 {d0}, [%[result]]!\n"
      "vst1.32 {d1[0]}, [%[result]], %[result_stride]\n"

      "vst1.32 {d6}, [%[result]]!\n"
      "vst1.32 {d7[0]}, [%[result]], %[result_stride]\n"

      : [count] "+r"(count), [result_scale] "+r"(result_scale),
        [result_stride] "+r"(result_stride), [rhs] "+r"(rhs), [lhs] "+r"(lhs),
        [result] "+r"(result)
      :
      : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
        "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20",
        "d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "cc",
        "memory");
}

inline void mul_3x8_1x8_float_lhsadd_rhsadd(const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs,
|
|
std::int32_t count, float* result,
|
|
std::int32_t result_stride,
|
|
float result_scale) {
|
|
asm volatile(
|
|
// Clear aggregators.
|
|
"vmov.i32 q0, #0\n"
|
|
"vmov.i32 q1, #0\n"
|
|
"vmov.i32 q2, #0\n"
|
|
|
|
"pld [%[lhs]]\n"
|
|
"pld [%[rhs]]\n"
|
|
// General NxM lanes loop.
|
|
"1:"
|
|
|
|
// Subtract counter.
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
"vld1.8 {d6, d7, d8}, [%[lhs]:64]!\n"
|
|
"vld1.8 {d9}, [%[rhs]:64]!\n"
|
|
"pld [%[lhs], #64]\n"
|
|
"pld [%[rhs], #64]\n"
|
|
"vmull.u8 q5, d9, d6\n"
|
|
"vmull.u8 q6, d9, d7\n"
|
|
"vmull.u8 q7, d9, d8\n"
|
|
"vpadal.u16 q0, q5\n"
|
|
"vpadal.u16 q1, q6\n"
|
|
"vpadal.u16 q2, q7\n"
|
|
|
|
// Loop break.
|
|
"bne 1b\n"
|
|
|
|
"vld1.32 {q4}, [%[lhs]:64]\n"
|
|
"vdup.32 d6, d8[0]\n"
|
|
"vdup.32 d7, d8[1]\n"
|
|
"vdup.32 d10, d9[0]\n"
|
|
"vld1.32 {d11}, [%[rhs]:64]\n"
|
|
"vdup.32 d12, %[result_scale]\n"
|
|
|
|
// Horizontal reduce aggregators.
|
|
"vpadd.u32 d0, d0, d1\n"
|
|
"vpadd.u32 d2, d2, d3\n"
|
|
"vpadd.u32 d4, d4, d5\n"
|
|
|
|
// Reduce rows.
|
|
"vpadd.u32 d0, d0, d0\n"
|
|
"vpadd.u32 d2, d2, d2\n"
|
|
"vpadd.u32 d4, d4, d4\n"
|
|
|
|
// Add lhs offsets to aggregated rows.
|
|
"vadd.s32 d0, d0, d6\n"
|
|
"vadd.s32 d2, d2, d7\n"
|
|
"vadd.s32 d4, d4, d10\n"
|
|
|
|
// Add rhs offset to aggregated rows.
|
|
"vadd.s32 d0, d0, d11\n"
|
|
"vadd.s32 d2, d2, d11\n"
|
|
"vadd.s32 d4, d4, d11\n"
|
|
|
|
// Convert to float. Multiply by result scale.
|
|
"vcvt.f32.s32 d0, d0\n"
|
|
"vcvt.f32.s32 d2, d2\n"
|
|
"vcvt.f32.s32 d4, d4\n"
|
|
"vmul.f32 d0, d0, d12\n"
|
|
"vmul.f32 d2, d2, d12\n"
|
|
"vmul.f32 d4, d4, d12\n"
|
|
|
|
// Store reduced rows.
|
|
"vst1.32 {d0[0]}, [%[result]], %[result_stride]\n"
|
|
"vst1.32 {d2[0]}, [%[result]], %[result_stride]\n"
|
|
"vst1.32 {d4[0]}, [%[result]], %[result_stride]\n"
|
|
: [count] "+r"(count), [result_scale] "+r"(result_scale),
|
|
[result_stride] "+r"(result_stride), [rhs] "+r"(rhs), [lhs] "+r"(lhs),
|
|
[result] "+r"(result)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
|
|
"d11", "d12", "d13", "d14", "d15", "cc", "memory");
|
|
}
|
|
|
|
inline void mul_3x8_2x8_float_lhsadd_rhsadd(const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs,
|
|
std::int32_t count, float* result,
|
|
std::int32_t result_stride,
|
|
float result_scale) {
|
|
asm volatile(
|
|
// Clear aggregators.
|
|
"vmov.i32 q0, #0\n"
|
|
"vmov.i32 q1, #0\n"
|
|
"vmov.i32 q2, #0\n"
|
|
"vmov.i32 q3, q0\n"
|
|
"vmov.i32 q4, q1\n"
|
|
"vmov.i32 q5, q2\n"
|
|
|
|
"pld [%[lhs]]\n"
|
|
"pld [%[rhs]]\n"
|
|
// General NxM lanes loop.
|
|
"1:"
|
|
|
|
// Subtract counter.
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
"vld1.8 {d12, d13, d14}, [%[lhs]:64]!\n"
|
|
"vld1.8 {d15, d16}, [%[rhs]:64]!\n"
|
|
"pld [%[lhs], #64]\n"
|
|
"pld [%[rhs], #64]\n"
|
|
"vmull.u8 q9, d15, d12\n"
|
|
"vmull.u8 q10, d16, d12\n"
|
|
"vmull.u8 q11, d15, d13\n"
|
|
"vmull.u8 q12, d16, d13\n"
|
|
"vmull.u8 q13, d15, d14\n"
|
|
"vmull.u8 q14, d16, d14\n"
|
|
"vpadal.u16 q0, q9\n"
|
|
"vpadal.u16 q1, q10\n"
|
|
"vpadal.u16 q2, q11\n"
|
|
"vpadal.u16 q3, q12\n"
|
|
"vpadal.u16 q4, q13\n"
|
|
"vpadal.u16 q5, q14\n"
|
|
|
|
// Loop break.
|
|
"bne 1b\n"
|
|
|
|
"vld1.32 {q6}, [%[lhs]:64]\n"
|
|
"vdup.32 d14, d12[0]\n"
|
|
"vdup.32 d15, d12[1]\n"
|
|
"vdup.32 d16, d13[0]\n"
|
|
"vld1.32 {d17}, [%[rhs]:64]\n"
|
|
"vdup.32 d18, %[result_scale]\n"
|
|
|
|
// Horizontal reduce aggregators.
|
|
"vpadd.u32 d0, d0, d1\n"
|
|
"vpadd.u32 d2, d2, d3\n"
|
|
"vpadd.u32 d4, d4, d5\n"
|
|
"vpadd.u32 d6, d6, d7\n"
|
|
"vpadd.u32 d8, d8, d9\n"
|
|
"vpadd.u32 d10, d10, d11\n"
|
|
|
|
// Reduce rows.
|
|
"vpadd.u32 d0, d0, d2\n"
|
|
"vpadd.u32 d4, d4, d6\n"
|
|
"vpadd.u32 d8, d8, d10\n"
|
|
|
|
// Add lhs offsets to aggregated rows.
|
|
"vadd.s32 d0, d0, d14\n"
|
|
"vadd.s32 d4, d4, d15\n"
|
|
"vadd.s32 d8, d8, d16\n"
|
|
|
|
// Add rhs offset to aggregated rows.
|
|
"vadd.s32 d0, d0, d17\n"
|
|
"vadd.s32 d4, d4, d17\n"
|
|
"vadd.s32 d8, d8, d17\n"
|
|
|
|
// Convert to float. Multiply by result scale.
|
|
"vcvt.f32.s32 d0, d0\n"
|
|
"vcvt.f32.s32 d4, d4\n"
|
|
"vcvt.f32.s32 d8, d8\n"
|
|
"vmul.f32 d0, d0, d18\n"
|
|
"vmul.f32 d4, d4, d18\n"
|
|
"vmul.f32 d8, d8, d18\n"
|
|
|
|
// Store reduced rows.
|
|
"vst1.32 {d0}, [%[result]], %[result_stride]\n"
|
|
"vst1.32 {d4}, [%[result]], %[result_stride]\n"
|
|
"vst1.32 {d8}, [%[result]], %[result_stride]\n"
|
|
: [count] "+r"(count), [result_scale] "+r"(result_scale),
|
|
[result_stride] "+r"(result_stride), [rhs] "+r"(rhs), [lhs] "+r"(lhs),
|
|
[result] "+r"(result)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
|
|
"d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20",
|
|
"d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "cc",
|
|
"memory");
|
|
}
|
|
|
|
inline void mul_3x8_3x8_float_lhsadd_rhsadd(const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs,
|
|
std::int32_t count, float* result,
|
|
std::int32_t result_stride,
|
|
float result_scale) {
|
|
asm volatile(
|
|
// Clear aggregators.
|
|
"vmov.i32 q0, #0\n"
|
|
"vmov.i32 q1, #0\n"
|
|
"vmov.i32 q2, #0\n"
|
|
"vmov.i32 q3, q0\n"
|
|
"vmov.i32 q4, q1\n"
|
|
"vmov.i32 q5, q2\n"
|
|
"vmov.i32 q6, q3\n"
|
|
"vmov.i32 q7, q4\n"
|
|
"vmov.i32 q8, q5\n"
|
|
|
|
"pld [%[lhs]]\n"
|
|
"pld [%[rhs]]\n"
|
|
// 3x3 lanes loop.
|
|
"1:"
|
|
|
|
// Subtract counter.
|
|
"subs %[count], %[count], #8\n"
|
|
|
|
"vld1.8 {d18, d19, d20}, [%[lhs]:64]!\n"
|
|
"vld1.8 {d21, d22, d23}, [%[rhs]:64]!\n"
|
|
"pld [%[lhs], #64]\n"
|
|
"pld [%[rhs], #64]\n"
|
|
"vmull.u8 q12, d18, d21\n"
|
|
"vmull.u8 q13, d18, d22\n"
|
|
"vmull.u8 q14, d18, d23\n"
|
|
"vmull.u8 q15, d19, d21\n"
|
|
"vpadal.u16 q0, q12\n"
|
|
"vpadal.u16 q1, q13\n"
|
|
"vpadal.u16 q2, q14\n"
|
|
"vpadal.u16 q3, q15\n"
|
|
"vmull.u8 q12, d19, d22\n"
|
|
"vmull.u8 q13, d19, d23\n"
|
|
"vmull.u8 q14, d20, d21\n"
|
|
"vmull.u8 q15, d20, d22\n"
|
|
"vmull.u8 q9, d20, d23\n"
|
|
"vpadal.u16 q4, q12\n"
|
|
"vpadal.u16 q5, q13\n"
|
|
"vpadal.u16 q6, q14\n"
|
|
"vpadal.u16 q7, q15\n"
|
|
"vpadal.u16 q8, q9\n"
|
|
|
|
// Loop break.
|
|
"bne 1b\n"
|
|
|
|
"vld1.32 {q9}, [%[lhs]:64]\n"
|
|
"vdup.32 q10, d18[0]\n"
|
|
"vdup.32 q11, d18[1]\n"
|
|
"vdup.32 q12, d19[0]\n"
|
|
"vld1.32 {q13}, [%[rhs]:64]\n"
|
|
"vdup.32 q14, %[result_scale]\n"
|
|
|
|
// Change stride because storing in two ops.
|
|
"sub %[result_stride], %[result_stride], #8\n"
|
|
|
|
// Horizontal reduce aggregators.
|
|
"vpadd.u32 d0, d0, d1\n"
|
|
"vpadd.u32 d2, d2, d3\n"
|
|
"vpadd.u32 d4, d4, d5\n"
|
|
"vpadd.u32 d6, d6, d7\n"
|
|
"vpadd.u32 d8, d8, d9\n"
|
|
"vpadd.u32 d10, d10, d11\n"
|
|
"vpadd.u32 d12, d12, d13\n"
|
|
"vpadd.u32 d14, d14, d15\n"
|
|
"vpadd.u32 d16, d16, d17\n"
|
|
|
|
// Reduce rows.
|
|
"vpadd.u32 d0, d0, d2\n"
|
|
"vpadd.u32 d1, d4, d4\n"
|
|
"vpadd.u32 d6, d6, d8\n"
|
|
"vpadd.u32 d7, d10, d10\n"
|
|
"vpadd.u32 d12, d12, d14\n"
|
|
"vpadd.u32 d13, d16, d16\n"
|
|
|
|
// Add lhs offsets to aggregated rows.
|
|
"vadd.s32 q0, q0, q10\n"
|
|
"vadd.s32 q3, q3, q11\n"
|
|
"vadd.s32 q6, q6, q12\n"
|
|
|
|
// Add rhs offset to aggregated rows.
|
|
"vadd.s32 q0, q0, q13\n"
|
|
"vadd.s32 q3, q3, q13\n"
|
|
"vadd.s32 q6, q6, q13\n"
|
|
|
|
// Convert to float. Multiply by result scale.
|
|
"vcvt.f32.s32 q0, q0\n"
|
|
"vcvt.f32.s32 q3, q3\n"
|
|
"vcvt.f32.s32 q6, q6\n"
|
|
"vmul.f32 q0, q0, q14\n"
|
|
"vmul.f32 q3, q3, q14\n"
|
|
"vmul.f32 q6, q6, q14\n"
|
|
|
|
// Store reduced rows.
|
|
"vst1.32 {d0}, [%[result]]!\n"
|
|
"vst1.32 {d1[0]}, [%[result]], %[result_stride]\n"
|
|
|
|
"vst1.32 {d6}, [%[result]]!\n"
|
|
"vst1.32 {d7[0]}, [%[result]], %[result_stride]\n"
|
|
|
|
"vst1.32 {d12}, [%[result]]!\n"
|
|
"vst1.32 {d13[0]}, [%[result]], %[result_stride]\n"
|
|
|
|
: [count] "+r"(count), [result_scale] "+r"(result_scale),
|
|
[result_stride] "+r"(result_stride), [rhs] "+r"(rhs), [lhs] "+r"(lhs),
|
|
[result] "+r"(result)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
|
|
"d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20",
|
|
"d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "d30",
|
|
"d31", "cc", "memory");
|
|
}
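
// A plain-C++ reference sketch of the contract assumed for the
// mul_*_float_lhsadd_rhsadd kernels above. This helper is hypothetical (not
// part of the generated library): each kernel computes, for a block of lhs
// rows and rhs rows of zipped uint8 data,
//   result[i][j] = (dot(lhs_row_i, rhs_row_j) + lhs_term[i] + rhs_term[j])
//                  * result_scale,
// where the int32 terms are the precomputed offsets stored right after the
// padded uint8 data. The sketch assumes simple row-major blocks instead of
// the interleaved 8-byte-chunk layout the assembly actually consumes, and
// takes result_stride in bytes, matching the post-indexed stores above.
inline void mul_float_lhsadd_rhsadd_reference(
    const std::uint8_t* lhs, const std::uint8_t* rhs, std::int32_t rows,
    std::int32_t cols, std::int32_t count, float* result,
    std::int32_t result_stride, float result_scale) {
  // Per-row / per-column int32 terms stored after the padded uint8 blocks.
  const std::int32_t* lhs_terms =
      reinterpret_cast<const std::int32_t*>(lhs + rows * count);
  const std::int32_t* rhs_terms =
      reinterpret_cast<const std::int32_t*>(rhs + cols * count);
  for (std::int32_t i = 0; i < rows; ++i) {
    float* result_row = reinterpret_cast<float*>(
        reinterpret_cast<std::uint8_t*>(result) + i * result_stride);
    for (std::int32_t j = 0; j < cols; ++j) {
      // uint8 dot product over the (padded) depth.
      std::int32_t accumulator = 0;
      for (std::int32_t k = 0; k < count; ++k) {
        accumulator += static_cast<std::int32_t>(lhs[i * count + k]) *
                       static_cast<std::int32_t>(rhs[j * count + k]);
      }
      // Fold in both offset terms, convert to float and scale.
      result_row[j] =
          static_cast<float>(accumulator + lhs_terms[i] + rhs_terms[j]) *
          result_scale;
    }
  }
}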
|
|
|
|
void qnt_1x8_aligned(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination, std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t rounding_offset, std::int32_t shift) {
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d8, d9, d10, d11}, [%[source]:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"vadd.i32 q4, q4, q3\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vmul.i32 q4, q4, q0\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vshl.s32 q4, q4, q2\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vqmovn.s32 d12, q4\n"
|
|
"vqmovn.s32 d13, q5\n"
|
|
"vqmovun.s16 d12, q6\n"
|
|
"vst1.8 {d12}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
|
|
"d11", "d12", "d13", "cc", "memory");
|
|
}
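
// A scalar reference sketch of the per-element transform the qnt_* kernels
// above apply when quantizing int32 accumulators back down to uint8. This
// helper is hypothetical (not part of the generated library): add the
// per-row offset, multiply, add the rounding term, shift with vshl.s32
// semantics (a negative shift count shifts right), then saturate to the
// uint8 range. The NEON code keeps its intermediate in 32 bits and narrows
// with vqmovn/vqmovun; a wider intermediate is used here purely for clarity.
inline std::uint8_t qnt_reference_value(std::int32_t value,
                                        std::int32_t offset,
                                        std::int32_t multiplicative_offset,
                                        std::int32_t rounding_offset,
                                        std::int32_t shift) {
  std::int64_t x = static_cast<std::int64_t>(value) + offset;
  x = x * multiplicative_offset + rounding_offset;
  // vshl.s32: positive count shifts left, negative count shifts right.
  x = (shift >= 0) ? (x << shift) : (x >> -shift);
  // Saturate to [0, 255], as the vqmovn.s32 / vqmovun.s16 pair does.
  if (x < 0) x = 0;
  if (x > 255) x = 255;
  return static_cast<std::uint8_t>(x);
}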
|
|
|
|
void qnt_1x8_1_aligned(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination,
|
|
std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t rounding_offset, std::int32_t shift) {
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"subs %[count], %[count], #1\n"
|
|
"beq 2f\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d8, d9, d10, d11}, [%[source]:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"vadd.i32 q4, q4, q3\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vmul.i32 q4, q4, q0\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vshl.s32 q4, q4, q2\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vqmovn.s32 d12, q4\n"
|
|
"vqmovn.s32 d13, q5\n"
|
|
"vqmovun.s16 d12, q6\n"
|
|
"vst1.8 {d12}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
"2:"
|
|
"vld1.32 {d8[0]}, [%[source]]\n"
|
|
"vadd.i32 q4, q4, q3\n"
|
|
"vmul.i32 q4, q4, q0\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vshl.s32 q4, q4, q2\n"
|
|
"vqmovn.s32 d12, q4\n"
|
|
"vqmovun.s16 d12, q6\n"
|
|
"vst1.8 {d12[0]}, [%[destination]]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
|
|
"d11", "d12", "d13", "cc", "memory");
|
|
}
|
|
|
|
void qnt_1x8_2_aligned(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination,
|
|
std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t rounding_offset, std::int32_t shift) {
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"subs %[count], %[count], #2\n"
|
|
"beq 2f\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d8, d9, d10, d11}, [%[source]:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"vadd.i32 q4, q4, q3\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vmul.i32 q4, q4, q0\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vshl.s32 q4, q4, q2\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vqmovn.s32 d12, q4\n"
|
|
"vqmovn.s32 d13, q5\n"
|
|
"vqmovun.s16 d12, q6\n"
|
|
"vst1.8 {d12}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
"2:"
|
|
"vld1.32 {d8}, [%[source]:64]\n"
|
|
"vadd.i32 q4, q4, q3\n"
|
|
"vmul.i32 q4, q4, q0\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vshl.s32 q4, q4, q2\n"
|
|
"vqmovn.s32 d12, q4\n"
|
|
"vqmovun.s16 d12, q6\n"
|
|
"vst1.16 {d12[0]}, [%[destination]]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
|
|
"d11", "d12", "d13", "cc", "memory");
|
|
}
|
|
|
|
void qnt_1x8_3_aligned(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination,
|
|
std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t rounding_offset, std::int32_t shift) {
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"subs %[count], %[count], #3\n"
|
|
"beq 2f\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d8, d9, d10, d11}, [%[source]:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"vadd.i32 q4, q4, q3\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vmul.i32 q4, q4, q0\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vshl.s32 q4, q4, q2\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vqmovn.s32 d12, q4\n"
|
|
"vqmovn.s32 d13, q5\n"
|
|
"vqmovun.s16 d12, q6\n"
|
|
"vst1.8 {d12}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
"2:"
|
|
"vld1.32 {d8}, [%[source]:64]!\n"
|
|
"vld1.32 {d9[0]}, [%[source]]\n"
|
|
"vadd.i32 q4, q4, q3\n"
|
|
"vmul.i32 q4, q4, q0\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vshl.s32 q4, q4, q2\n"
|
|
"vqmovn.s32 d12, q4\n"
|
|
"vqmovun.s16 d12, q6\n"
|
|
"vst1.16 {d12[0]}, [%[destination]]!\n"
|
|
"vst1.8 {d12[2]}, [%[destination]]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
|
|
"d11", "d12", "d13", "cc", "memory");
|
|
}
|
|
|
|
void qnt_1x8_4_aligned(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination,
|
|
std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t rounding_offset, std::int32_t shift) {
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"subs %[count], %[count], #4\n"
|
|
"beq 2f\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d8, d9, d10, d11}, [%[source]:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"vadd.i32 q4, q4, q3\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vmul.i32 q4, q4, q0\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vshl.s32 q4, q4, q2\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vqmovn.s32 d12, q4\n"
|
|
"vqmovn.s32 d13, q5\n"
|
|
"vqmovun.s16 d12, q6\n"
|
|
"vst1.8 {d12}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
"2:"
|
|
"vld1.32 {d8, d9}, [%[source]:64]\n"
|
|
"vadd.i32 q4, q4, q3\n"
|
|
"vmul.i32 q4, q4, q0\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vshl.s32 q4, q4, q2\n"
|
|
"vqmovn.s32 d12, q4\n"
|
|
"vqmovun.s16 d12, q6\n"
|
|
"vst1.32 {d12[0]}, [%[destination]]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
|
|
"d11", "d12", "d13", "cc", "memory");
|
|
}
|
|
|
|
void qnt_1x8_5_aligned(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination,
|
|
std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t rounding_offset, std::int32_t shift) {
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"subs %[count], %[count], #5\n"
|
|
"beq 2f\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d8, d9, d10, d11}, [%[source]:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"vadd.i32 q4, q4, q3\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vmul.i32 q4, q4, q0\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vshl.s32 q4, q4, q2\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vqmovn.s32 d12, q4\n"
|
|
"vqmovn.s32 d13, q5\n"
|
|
"vqmovun.s16 d12, q6\n"
|
|
"vst1.8 {d12}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
"2:"
|
|
"vld1.32 {d8, d9}, [%[source]:64]!\n"
|
|
"vld1.32 {d10[0]}, [%[source]]\n"
|
|
"vadd.i32 q4, q4, q3\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vmul.i32 q4, q4, q0\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vshl.s32 q4, q4, q2\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vqmovn.s32 d12, q4\n"
|
|
"vqmovn.s32 d13, q5\n"
|
|
"vqmovun.s16 d12, q6\n"
|
|
"vst1.32 {d12[0]}, [%[destination]]!\n"
|
|
"vst1.8 {d12[4]}, [%[destination]]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
|
|
"d11", "d12", "d13", "cc", "memory");
|
|
}
|
|
|
|
void qnt_1x8_6_aligned(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination,
|
|
std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t rounding_offset, std::int32_t shift) {
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"subs %[count], %[count], #6\n"
|
|
"beq 2f\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d8, d9, d10, d11}, [%[source]:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"vadd.i32 q4, q4, q3\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vmul.i32 q4, q4, q0\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vshl.s32 q4, q4, q2\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vqmovn.s32 d12, q4\n"
|
|
"vqmovn.s32 d13, q5\n"
|
|
"vqmovun.s16 d12, q6\n"
|
|
"vst1.8 {d12}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
"2:"
|
|
"vld1.32 {d8, d9, d10}, [%[source]:64]\n"
|
|
"vadd.i32 q4, q4, q3\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vmul.i32 q4, q4, q0\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vshl.s32 q4, q4, q2\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vqmovn.s32 d12, q4\n"
|
|
"vqmovn.s32 d13, q5\n"
|
|
"vqmovun.s16 d12, q6\n"
|
|
"vst1.32 {d12[0]}, [%[destination]]!\n"
|
|
"vst1.16 {d12[2]}, [%[destination]]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
|
|
"d11", "d12", "d13", "cc", "memory");
|
|
}
|
|
|
|
void qnt_1x8_7_aligned(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination,
|
|
std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t rounding_offset, std::int32_t shift) {
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"subs %[count], %[count], #7\n"
|
|
"beq 2f\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d8, d9, d10, d11}, [%[source]:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"vadd.i32 q4, q4, q3\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vmul.i32 q4, q4, q0\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vshl.s32 q4, q4, q2\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vqmovn.s32 d12, q4\n"
|
|
"vqmovn.s32 d13, q5\n"
|
|
"vqmovun.s16 d12, q6\n"
|
|
"vst1.8 {d12}, [%[destination]:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
"2:"
|
|
"vld1.32 {d8, d9, d10}, [%[source]:64]!\n"
|
|
"vld1.32 {d11[0]}, [%[source]]\n"
|
|
"vadd.i32 q4, q4, q3\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vmul.i32 q4, q4, q0\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vshl.s32 q4, q4, q2\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vqmovn.s32 d12, q4\n"
|
|
"vqmovn.s32 d13, q5\n"
|
|
"vqmovun.s16 d12, q6\n"
|
|
"vst1.32 {d12[0]}, [%[destination]]!\n"
|
|
"vst1.16 {d12[2]}, [%[destination]]!\n"
|
|
"vst1.8 {d12[6]}, [%[destination]]!\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
|
|
"d11", "d12", "d13", "cc", "memory");
|
|
}
|
|
|
|
void qnt_2x8_aligned(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination, std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t rounding_offset, std::int32_t shift) {
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
|
|
"add r0, %[source], %[stride]\n"
|
|
"add r1, %[destination], %[destination_stride]\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d10, d11, d12, d13}, [%[source]:64]!\n"
|
|
"vld1.32 {d14, d15, d16, d17}, [r0:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"pld [r0]\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q4\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vqmovn.s32 d18, q5\n"
|
|
"vqmovn.s32 d19, q6\n"
|
|
"vqmovn.s32 d20, q7\n"
|
|
"vqmovn.s32 d21, q8\n"
|
|
"vqmovun.s16 d18, q9\n"
|
|
"vqmovun.s16 d20, q10\n"
|
|
"vst1.8 {d18}, [%[destination]:64]!\n"
|
|
"vst1.8 {d20}, [r1:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
|
|
"d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19",
|
|
"d20", "d21", "cc", "memory");
|
|
}
|
|
|
|
void qnt_2x8_1_aligned(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination,
|
|
std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t rounding_offset, std::int32_t shift) {
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
|
|
"add r0, %[source], %[stride]\n"
|
|
"add r1, %[destination], %[destination_stride]\n"
|
|
"subs %[count], %[count], #1\n"
|
|
"beq 2f\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d10, d11, d12, d13}, [%[source]:64]!\n"
|
|
"vld1.32 {d14, d15, d16, d17}, [r0:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"pld [r0]\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q4\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vqmovn.s32 d18, q5\n"
|
|
"vqmovn.s32 d19, q6\n"
|
|
"vqmovn.s32 d20, q7\n"
|
|
"vqmovn.s32 d21, q8\n"
|
|
"vqmovun.s16 d18, q9\n"
|
|
"vqmovun.s16 d20, q10\n"
|
|
"vst1.8 {d18}, [%[destination]:64]!\n"
|
|
"vst1.8 {d20}, [r1:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
"2:"
|
|
"vld1.32 {d10[0]}, [%[source]]\n"
|
|
"vld1.32 {d14[0]}, [r0]\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vadd.i32 q7, q7, q4\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vqmovn.s32 d18, q5\n"
|
|
"vqmovn.s32 d20, q7\n"
|
|
"vqmovun.s16 d18, q9\n"
|
|
"vqmovun.s16 d20, q10\n"
|
|
"vst1.8 {d18[0]}, [%[destination]]\n"
|
|
"vst1.8 {d20[0]}, [r1]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
|
|
"d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19",
|
|
"d20", "d21", "cc", "memory");
|
|
}
|
|
|
|
void qnt_2x8_2_aligned(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination,
|
|
std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t rounding_offset, std::int32_t shift) {
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
|
|
"add r0, %[source], %[stride]\n"
|
|
"add r1, %[destination], %[destination_stride]\n"
|
|
"subs %[count], %[count], #2\n"
|
|
"beq 2f\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d10, d11, d12, d13}, [%[source]:64]!\n"
|
|
"vld1.32 {d14, d15, d16, d17}, [r0:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"pld [r0]\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q4\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vqmovn.s32 d18, q5\n"
|
|
"vqmovn.s32 d19, q6\n"
|
|
"vqmovn.s32 d20, q7\n"
|
|
"vqmovn.s32 d21, q8\n"
|
|
"vqmovun.s16 d18, q9\n"
|
|
"vqmovun.s16 d20, q10\n"
|
|
"vst1.8 {d18}, [%[destination]:64]!\n"
|
|
"vst1.8 {d20}, [r1:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
"2:"
|
|
"vld1.32 {d10}, [%[source]:64]\n"
|
|
"vld1.32 {d14}, [r0:64]\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vadd.i32 q7, q7, q4\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vqmovn.s32 d18, q5\n"
|
|
"vqmovn.s32 d20, q7\n"
|
|
"vqmovun.s16 d18, q9\n"
|
|
"vqmovun.s16 d20, q10\n"
|
|
"vst1.16 {d18[0]}, [%[destination]]\n"
|
|
"vst1.16 {d20[0]}, [r1]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
|
|
"d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19",
|
|
"d20", "d21", "cc", "memory");
|
|
}
|
|
|
|
void qnt_2x8_3_aligned(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination,
|
|
std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t rounding_offset, std::int32_t shift) {
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
|
|
"add r0, %[source], %[stride]\n"
|
|
"add r1, %[destination], %[destination_stride]\n"
|
|
"subs %[count], %[count], #3\n"
|
|
"beq 2f\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d10, d11, d12, d13}, [%[source]:64]!\n"
|
|
"vld1.32 {d14, d15, d16, d17}, [r0:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"pld [r0]\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q4\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vqmovn.s32 d18, q5\n"
|
|
"vqmovn.s32 d19, q6\n"
|
|
"vqmovn.s32 d20, q7\n"
|
|
"vqmovn.s32 d21, q8\n"
|
|
"vqmovun.s16 d18, q9\n"
|
|
"vqmovun.s16 d20, q10\n"
|
|
"vst1.8 {d18}, [%[destination]:64]!\n"
|
|
"vst1.8 {d20}, [r1:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
"2:"
|
|
"vld1.32 {d10}, [%[source]:64]!\n"
|
|
"vld1.32 {d14}, [r0:64]!\n"
|
|
"vld1.32 {d11[0]}, [%[source]]\n"
|
|
"vld1.32 {d15[0]}, [r0]\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vadd.i32 q7, q7, q4\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vqmovn.s32 d18, q5\n"
|
|
"vqmovn.s32 d20, q7\n"
|
|
"vqmovun.s16 d18, q9\n"
|
|
"vqmovun.s16 d20, q10\n"
|
|
"vst1.16 {d18[0]}, [%[destination]]!\n"
|
|
"vst1.16 {d20[0]}, [r1]!\n"
|
|
"vst1.8 {d18[2]}, [%[destination]]\n"
|
|
"vst1.8 {d20[2]}, [r1]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
|
|
"d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19",
|
|
"d20", "d21", "cc", "memory");
|
|
}
|
|
|
|
void qnt_2x8_4_aligned(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination,
|
|
std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t rounding_offset, std::int32_t shift) {
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
|
|
"add r0, %[source], %[stride]\n"
|
|
"add r1, %[destination], %[destination_stride]\n"
|
|
"subs %[count], %[count], #4\n"
|
|
"beq 2f\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d10, d11, d12, d13}, [%[source]:64]!\n"
|
|
"vld1.32 {d14, d15, d16, d17}, [r0:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"pld [r0]\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q4\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vqmovn.s32 d18, q5\n"
|
|
"vqmovn.s32 d19, q6\n"
|
|
"vqmovn.s32 d20, q7\n"
|
|
"vqmovn.s32 d21, q8\n"
|
|
"vqmovun.s16 d18, q9\n"
|
|
"vqmovun.s16 d20, q10\n"
|
|
"vst1.8 {d18}, [%[destination]:64]!\n"
|
|
"vst1.8 {d20}, [r1:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
"2:"
|
|
"vld1.32 {d10, d11}, [%[source]:64]\n"
|
|
"vld1.32 {d14, d15}, [r0:64]\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vadd.i32 q7, q7, q4\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vqmovn.s32 d18, q5\n"
|
|
"vqmovn.s32 d20, q7\n"
|
|
"vqmovun.s16 d18, q9\n"
|
|
"vqmovun.s16 d20, q10\n"
|
|
"vst1.32 {d18[0]}, [%[destination]]\n"
|
|
"vst1.32 {d20[0]}, [r1]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
|
|
"d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19",
|
|
"d20", "d21", "cc", "memory");
|
|
}
|
|
|
|
void qnt_2x8_5_aligned(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination,
|
|
std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t rounding_offset, std::int32_t shift) {
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
|
|
"add r0, %[source], %[stride]\n"
|
|
"add r1, %[destination], %[destination_stride]\n"
|
|
"subs %[count], %[count], #5\n"
|
|
"beq 2f\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d10, d11, d12, d13}, [%[source]:64]!\n"
|
|
"vld1.32 {d14, d15, d16, d17}, [r0:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"pld [r0]\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q4\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vqmovn.s32 d18, q5\n"
|
|
"vqmovn.s32 d19, q6\n"
|
|
"vqmovn.s32 d20, q7\n"
|
|
"vqmovn.s32 d21, q8\n"
|
|
"vqmovun.s16 d18, q9\n"
|
|
"vqmovun.s16 d20, q10\n"
|
|
"vst1.8 {d18}, [%[destination]:64]!\n"
|
|
"vst1.8 {d20}, [r1:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
"2:"
|
|
"vld1.32 {d10, d11}, [%[source]:64]!\n"
|
|
"vld1.32 {d14, d15}, [r0:64]!\n"
|
|
"vld1.32 {d12[0]}, [%[source]]\n"
|
|
"vld1.32 {d16[0]}, [r0]\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q4\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vqmovn.s32 d18, q5\n"
|
|
"vqmovn.s32 d19, q6\n"
|
|
"vqmovn.s32 d20, q7\n"
|
|
"vqmovn.s32 d21, q8\n"
|
|
"vqmovun.s16 d18, q9\n"
|
|
"vqmovun.s16 d20, q10\n"
|
|
"vst1.32 {d18[0]}, [%[destination]]!\n"
|
|
"vst1.32 {d20[0]}, [r1]!\n"
|
|
"vst1.8 {d18[4]}, [%[destination]]\n"
|
|
"vst1.8 {d20[4]}, [r1]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
|
|
"d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19",
|
|
"d20", "d21", "cc", "memory");
|
|
}
|
|
|
|
void qnt_2x8_6_aligned(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination,
|
|
std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t rounding_offset, std::int32_t shift) {
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
|
|
"add r0, %[source], %[stride]\n"
|
|
"add r1, %[destination], %[destination_stride]\n"
|
|
"subs %[count], %[count], #6\n"
|
|
"beq 2f\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d10, d11, d12, d13}, [%[source]:64]!\n"
|
|
"vld1.32 {d14, d15, d16, d17}, [r0:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"pld [r0]\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q4\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vqmovn.s32 d18, q5\n"
|
|
"vqmovn.s32 d19, q6\n"
|
|
"vqmovn.s32 d20, q7\n"
|
|
"vqmovn.s32 d21, q8\n"
|
|
"vqmovun.s16 d18, q9\n"
|
|
"vqmovun.s16 d20, q10\n"
|
|
"vst1.8 {d18}, [%[destination]:64]!\n"
|
|
"vst1.8 {d20}, [r1:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
"2:"
|
|
"vld1.32 {d10, d11, d12}, [%[source]:64]\n"
|
|
"vld1.32 {d14, d15, d16}, [r0:64]\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q4\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vqmovn.s32 d18, q5\n"
|
|
"vqmovn.s32 d19, q6\n"
|
|
"vqmovn.s32 d20, q7\n"
|
|
"vqmovn.s32 d21, q8\n"
|
|
"vqmovun.s16 d18, q9\n"
|
|
"vqmovun.s16 d20, q10\n"
|
|
"vst1.32 {d18[0]}, [%[destination]]!\n"
|
|
"vst1.32 {d20[0]}, [r1]!\n"
|
|
"vst1.16 {d18[2]}, [%[destination]]\n"
|
|
"vst1.16 {d20[2]}, [r1]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
|
|
"d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19",
|
|
"d20", "d21", "cc", "memory");
|
|
}
|
|
|
|
void qnt_2x8_7_aligned(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination,
|
|
std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t rounding_offset, std::int32_t shift) {
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
|
|
"add r0, %[source], %[stride]\n"
|
|
"add r1, %[destination], %[destination_stride]\n"
|
|
"subs %[count], %[count], #7\n"
|
|
"beq 2f\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d10, d11, d12, d13}, [%[source]:64]!\n"
|
|
"vld1.32 {d14, d15, d16, d17}, [r0:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"pld [r0]\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q4\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vqmovn.s32 d18, q5\n"
|
|
"vqmovn.s32 d19, q6\n"
|
|
"vqmovn.s32 d20, q7\n"
|
|
"vqmovn.s32 d21, q8\n"
|
|
"vqmovun.s16 d18, q9\n"
|
|
"vqmovun.s16 d20, q10\n"
|
|
"vst1.8 {d18}, [%[destination]:64]!\n"
|
|
"vst1.8 {d20}, [r1:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
"2:"
|
|
"vld1.32 {d10, d11, d12}, [%[source]:64]!\n"
|
|
"vld1.32 {d14, d15, d16}, [r0:64]!\n"
|
|
"vld1.32 {d13[0]}, [%[source]]\n"
|
|
"vld1.32 {d17[0]}, [r0]\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q4\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vqmovn.s32 d18, q5\n"
|
|
"vqmovn.s32 d19, q6\n"
|
|
"vqmovn.s32 d20, q7\n"
|
|
"vqmovn.s32 d21, q8\n"
|
|
"vqmovun.s16 d18, q9\n"
|
|
"vqmovun.s16 d20, q10\n"
|
|
"vst1.32 {d18[0]}, [%[destination]]!\n"
|
|
"vst1.32 {d20[0]}, [r1]!\n"
|
|
"vst1.16 {d18[2]}, [%[destination]]!\n"
|
|
"vst1.16 {d20[2]}, [r1]!\n"
|
|
"vst1.8 {d18[6]}, [%[destination]]!\n"
|
|
"vst1.8 {d20[6]}, [r1]!\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
|
|
"d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19",
|
|
"d20", "d21", "cc", "memory");
|
|
}
|
|
|
|
void qnt_3x8_aligned(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination, std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t rounding_offset, std::int32_t shift) {
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d10[], d11[]}, [%[offsets]:32]!\n"
|
|
"add r0, %[source], %[stride]\n"
|
|
"add r1, %[destination], %[destination_stride]\n"
|
|
"add r2, r0, %[stride]\n"
|
|
"add r3, r1, %[destination_stride]\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d12, d13, d14, d15}, [%[source]:64]!\n"
|
|
"vld1.32 {d16, d17, d18, d19}, [r0:64]!\n"
|
|
"vld1.32 {d20, d21, d22, d23}, [r2:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"pld [r0]\n"
|
|
"pld [r2]\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q3\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vadd.i32 q9, q9, q4\n"
|
|
"vadd.i32 q10, q10, q5\n"
|
|
"vadd.i32 q11, q11, q5\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vmul.i32 q9, q9, q0\n"
|
|
"vmul.i32 q10, q10, q0\n"
|
|
"vmul.i32 q11, q11, q0\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vadd.i32 q9, q9, q1\n"
|
|
"vadd.i32 q10, q10, q1\n"
|
|
"vadd.i32 q11, q11, q1\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vshl.s32 q9, q9, q2\n"
|
|
"vshl.s32 q10, q10, q2\n"
|
|
"vshl.s32 q11, q11, q2\n"
|
|
"vqmovn.s32 d24, q6\n"
|
|
"vqmovn.s32 d25, q7\n"
|
|
"vqmovn.s32 d26, q8\n"
|
|
"vqmovn.s32 d27, q9\n"
|
|
"vqmovn.s32 d28, q10\n"
|
|
"vqmovn.s32 d29, q11\n"
|
|
"vqmovun.s16 d24, q12\n"
|
|
"vqmovun.s16 d26, q13\n"
|
|
"vqmovun.s16 d28, q14\n"
|
|
"vst1.8 {d24}, [%[destination]:64]!\n"
|
|
"vst1.8 {d26}, [r1:64]!\n"
|
|
"vst1.8 {d28}, [r3:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "r0", "r1", "r2", "r3", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
|
|
"d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
|
|
"d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
|
|
"d28", "d29", "cc", "memory");
|
|
}
|
|
|
|
void qnt_3x8_1_aligned(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination,
|
|
std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t rounding_offset, std::int32_t shift) {
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d10[], d11[]}, [%[offsets]:32]!\n"
|
|
"add r0, %[source], %[stride]\n"
|
|
"add r1, %[destination], %[destination_stride]\n"
|
|
"add r2, r0, %[stride]\n"
|
|
"add r3, r1, %[destination_stride]\n"
|
|
"subs %[count], %[count], #1\n"
|
|
"beq 2f\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d12, d13, d14, d15}, [%[source]:64]!\n"
|
|
"vld1.32 {d16, d17, d18, d19}, [r0:64]!\n"
|
|
"vld1.32 {d20, d21, d22, d23}, [r2:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"pld [r0]\n"
|
|
"pld [r2]\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q3\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vadd.i32 q9, q9, q4\n"
|
|
"vadd.i32 q10, q10, q5\n"
|
|
"vadd.i32 q11, q11, q5\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vmul.i32 q9, q9, q0\n"
|
|
"vmul.i32 q10, q10, q0\n"
|
|
"vmul.i32 q11, q11, q0\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vadd.i32 q9, q9, q1\n"
|
|
"vadd.i32 q10, q10, q1\n"
|
|
"vadd.i32 q11, q11, q1\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vshl.s32 q9, q9, q2\n"
|
|
"vshl.s32 q10, q10, q2\n"
|
|
"vshl.s32 q11, q11, q2\n"
|
|
"vqmovn.s32 d24, q6\n"
|
|
"vqmovn.s32 d25, q7\n"
|
|
"vqmovn.s32 d26, q8\n"
|
|
"vqmovn.s32 d27, q9\n"
|
|
"vqmovn.s32 d28, q10\n"
|
|
"vqmovn.s32 d29, q11\n"
|
|
"vqmovun.s16 d24, q12\n"
|
|
"vqmovun.s16 d26, q13\n"
|
|
"vqmovun.s16 d28, q14\n"
|
|
"vst1.8 {d24}, [%[destination]:64]!\n"
|
|
"vst1.8 {d26}, [r1:64]!\n"
|
|
"vst1.8 {d28}, [r3:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
"2:"
|
|
"vld1.32 {d12[0]}, [%[source]]\n"
|
|
"vld1.32 {d16[0]}, [r0]\n"
|
|
"vld1.32 {d20[0]}, [r2]\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vadd.i32 q10, q10, q5\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vmul.i32 q10, q10, q0\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vadd.i32 q10, q10, q1\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vshl.s32 q10, q10, q2\n"
|
|
"vqmovn.s32 d24, q6\n"
|
|
"vqmovn.s32 d26, q8\n"
|
|
"vqmovn.s32 d28, q10\n"
|
|
"vqmovun.s16 d24, q12\n"
|
|
"vqmovun.s16 d26, q13\n"
|
|
"vqmovun.s16 d28, q14\n"
|
|
"vst1.8 {d24[0]}, [%[destination]]\n"
|
|
"vst1.8 {d26[0]}, [r1]\n"
|
|
"vst1.8 {d28[0]}, [r3]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "r0", "r1", "r2", "r3", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
|
|
"d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
|
|
"d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
|
|
"d28", "d29", "cc", "memory");
|
|
}
|
|
|
|
void qnt_3x8_2_aligned(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination,
|
|
std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t rounding_offset, std::int32_t shift) {
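  // Three-row variant for counts of the form 8*n + 2: the tail loads two
  // leftover values per row and writes the results as one 16-bit store.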
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d10[], d11[]}, [%[offsets]:32]!\n"
|
|
"add r0, %[source], %[stride]\n"
|
|
"add r1, %[destination], %[destination_stride]\n"
|
|
"add r2, r0, %[stride]\n"
|
|
"add r3, r1, %[destination_stride]\n"
|
|
"subs %[count], %[count], #2\n"
|
|
"beq 2f\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d12, d13, d14, d15}, [%[source]:64]!\n"
|
|
"vld1.32 {d16, d17, d18, d19}, [r0:64]!\n"
|
|
"vld1.32 {d20, d21, d22, d23}, [r2:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"pld [r0]\n"
|
|
"pld [r2]\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q3\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vadd.i32 q9, q9, q4\n"
|
|
"vadd.i32 q10, q10, q5\n"
|
|
"vadd.i32 q11, q11, q5\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vmul.i32 q9, q9, q0\n"
|
|
"vmul.i32 q10, q10, q0\n"
|
|
"vmul.i32 q11, q11, q0\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vadd.i32 q9, q9, q1\n"
|
|
"vadd.i32 q10, q10, q1\n"
|
|
"vadd.i32 q11, q11, q1\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vshl.s32 q9, q9, q2\n"
|
|
"vshl.s32 q10, q10, q2\n"
|
|
"vshl.s32 q11, q11, q2\n"
|
|
"vqmovn.s32 d24, q6\n"
|
|
"vqmovn.s32 d25, q7\n"
|
|
"vqmovn.s32 d26, q8\n"
|
|
"vqmovn.s32 d27, q9\n"
|
|
"vqmovn.s32 d28, q10\n"
|
|
"vqmovn.s32 d29, q11\n"
|
|
"vqmovun.s16 d24, q12\n"
|
|
"vqmovun.s16 d26, q13\n"
|
|
"vqmovun.s16 d28, q14\n"
|
|
"vst1.8 {d24}, [%[destination]:64]!\n"
|
|
"vst1.8 {d26}, [r1:64]!\n"
|
|
"vst1.8 {d28}, [r3:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
"2:"
|
|
"vld1.32 {d12}, [%[source]:64]\n"
|
|
"vld1.32 {d16}, [r0:64]\n"
|
|
"vld1.32 {d20}, [r2:64]\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vadd.i32 q10, q10, q5\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vmul.i32 q10, q10, q0\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vadd.i32 q10, q10, q1\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vshl.s32 q10, q10, q2\n"
|
|
"vqmovn.s32 d24, q6\n"
|
|
"vqmovn.s32 d26, q8\n"
|
|
"vqmovn.s32 d28, q10\n"
|
|
"vqmovun.s16 d24, q12\n"
|
|
"vqmovun.s16 d26, q13\n"
|
|
"vqmovun.s16 d28, q14\n"
|
|
"vst1.16 {d24[0]}, [%[destination]]\n"
|
|
"vst1.16 {d26[0]}, [r1]\n"
|
|
"vst1.16 {d28[0]}, [r3]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "r0", "r1", "r2", "r3", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
|
|
"d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
|
|
"d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
|
|
"d28", "d29", "cc", "memory");
|
|
}
|
|
|
|
void qnt_3x8_3_aligned(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination,
|
|
std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t rounding_offset, std::int32_t shift) {
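  // Three-row variant for counts of the form 8*n + 3: the tail writes the
  // three leftover bytes per row as a 16-bit store plus an 8-bit lane store.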
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d10[], d11[]}, [%[offsets]:32]!\n"
|
|
"add r0, %[source], %[stride]\n"
|
|
"add r1, %[destination], %[destination_stride]\n"
|
|
"add r2, r0, %[stride]\n"
|
|
"add r3, r1, %[destination_stride]\n"
|
|
"subs %[count], %[count], #3\n"
|
|
"beq 2f\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d12, d13, d14, d15}, [%[source]:64]!\n"
|
|
"vld1.32 {d16, d17, d18, d19}, [r0:64]!\n"
|
|
"vld1.32 {d20, d21, d22, d23}, [r2:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"pld [r0]\n"
|
|
"pld [r2]\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q3\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vadd.i32 q9, q9, q4\n"
|
|
"vadd.i32 q10, q10, q5\n"
|
|
"vadd.i32 q11, q11, q5\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vmul.i32 q9, q9, q0\n"
|
|
"vmul.i32 q10, q10, q0\n"
|
|
"vmul.i32 q11, q11, q0\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vadd.i32 q9, q9, q1\n"
|
|
"vadd.i32 q10, q10, q1\n"
|
|
"vadd.i32 q11, q11, q1\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vshl.s32 q9, q9, q2\n"
|
|
"vshl.s32 q10, q10, q2\n"
|
|
"vshl.s32 q11, q11, q2\n"
|
|
"vqmovn.s32 d24, q6\n"
|
|
"vqmovn.s32 d25, q7\n"
|
|
"vqmovn.s32 d26, q8\n"
|
|
"vqmovn.s32 d27, q9\n"
|
|
"vqmovn.s32 d28, q10\n"
|
|
"vqmovn.s32 d29, q11\n"
|
|
"vqmovun.s16 d24, q12\n"
|
|
"vqmovun.s16 d26, q13\n"
|
|
"vqmovun.s16 d28, q14\n"
|
|
"vst1.8 {d24}, [%[destination]:64]!\n"
|
|
"vst1.8 {d26}, [r1:64]!\n"
|
|
"vst1.8 {d28}, [r3:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
"2:"
|
|
"vld1.32 {d12}, [%[source]:64]!\n"
|
|
"vld1.32 {d16}, [r0:64]!\n"
|
|
"vld1.32 {d20}, [r2:64]!\n"
|
|
"vld1.32 {d13[0]}, [%[source]]\n"
|
|
"vld1.32 {d17[0]}, [r0]\n"
|
|
"vld1.32 {d21[0]}, [r2]\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vadd.i32 q10, q10, q5\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vmul.i32 q10, q10, q0\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vadd.i32 q10, q10, q1\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vshl.s32 q10, q10, q2\n"
|
|
"vqmovn.s32 d24, q6\n"
|
|
"vqmovn.s32 d26, q8\n"
|
|
"vqmovn.s32 d28, q10\n"
|
|
"vqmovun.s16 d24, q12\n"
|
|
"vqmovun.s16 d26, q13\n"
|
|
"vqmovun.s16 d28, q14\n"
|
|
"vst1.16 {d24[0]}, [%[destination]]!\n"
|
|
"vst1.16 {d26[0]}, [r1]!\n"
|
|
"vst1.16 {d28[0]}, [r3]!\n"
|
|
"vst1.8 {d24[2]}, [%[destination]]\n"
|
|
"vst1.8 {d26[2]}, [r1]\n"
|
|
"vst1.8 {d28[2]}, [r3]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "r0", "r1", "r2", "r3", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
|
|
"d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
|
|
"d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
|
|
"d28", "d29", "cc", "memory");
|
|
}
|
|
|
|
void qnt_3x8_4_aligned(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination,
|
|
std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t rounding_offset, std::int32_t shift) {
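  // Three-row variant for counts of the form 8*n + 4: the tail loads a full
  // four-element block per row and writes the results as one 32-bit store.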
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d10[], d11[]}, [%[offsets]:32]!\n"
|
|
"add r0, %[source], %[stride]\n"
|
|
"add r1, %[destination], %[destination_stride]\n"
|
|
"add r2, r0, %[stride]\n"
|
|
"add r3, r1, %[destination_stride]\n"
|
|
"subs %[count], %[count], #4\n"
|
|
"beq 2f\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d12, d13, d14, d15}, [%[source]:64]!\n"
|
|
"vld1.32 {d16, d17, d18, d19}, [r0:64]!\n"
|
|
"vld1.32 {d20, d21, d22, d23}, [r2:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"pld [r0]\n"
|
|
"pld [r2]\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q3\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vadd.i32 q9, q9, q4\n"
|
|
"vadd.i32 q10, q10, q5\n"
|
|
"vadd.i32 q11, q11, q5\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vmul.i32 q9, q9, q0\n"
|
|
"vmul.i32 q10, q10, q0\n"
|
|
"vmul.i32 q11, q11, q0\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vadd.i32 q9, q9, q1\n"
|
|
"vadd.i32 q10, q10, q1\n"
|
|
"vadd.i32 q11, q11, q1\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vshl.s32 q9, q9, q2\n"
|
|
"vshl.s32 q10, q10, q2\n"
|
|
"vshl.s32 q11, q11, q2\n"
|
|
"vqmovn.s32 d24, q6\n"
|
|
"vqmovn.s32 d25, q7\n"
|
|
"vqmovn.s32 d26, q8\n"
|
|
"vqmovn.s32 d27, q9\n"
|
|
"vqmovn.s32 d28, q10\n"
|
|
"vqmovn.s32 d29, q11\n"
|
|
"vqmovun.s16 d24, q12\n"
|
|
"vqmovun.s16 d26, q13\n"
|
|
"vqmovun.s16 d28, q14\n"
|
|
"vst1.8 {d24}, [%[destination]:64]!\n"
|
|
"vst1.8 {d26}, [r1:64]!\n"
|
|
"vst1.8 {d28}, [r3:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
"2:"
|
|
"vld1.32 {d12, d13}, [%[source]:64]\n"
|
|
"vld1.32 {d16, d17}, [r0:64]\n"
|
|
"vld1.32 {d20, d21}, [r2:64]\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vadd.i32 q10, q10, q5\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vmul.i32 q10, q10, q0\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vadd.i32 q10, q10, q1\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vshl.s32 q10, q10, q2\n"
|
|
"vqmovn.s32 d24, q6\n"
|
|
"vqmovn.s32 d26, q8\n"
|
|
"vqmovn.s32 d28, q10\n"
|
|
"vqmovun.s16 d24, q12\n"
|
|
"vqmovun.s16 d26, q13\n"
|
|
"vqmovun.s16 d28, q14\n"
|
|
"vst1.32 {d24[0]}, [%[destination]]\n"
|
|
"vst1.32 {d26[0]}, [r1]\n"
|
|
"vst1.32 {d28[0]}, [r3]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "r0", "r1", "r2", "r3", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
|
|
"d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
|
|
"d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
|
|
"d28", "d29", "cc", "memory");
|
|
}
|
|
|
|
void qnt_3x8_5_aligned(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination,
|
|
std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t rounding_offset, std::int32_t shift) {
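  // Three-row variant for counts of the form 8*n + 5: the tail writes the
  // five leftover bytes per row as a 32-bit store plus an 8-bit lane store.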
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d10[], d11[]}, [%[offsets]:32]!\n"
|
|
"add r0, %[source], %[stride]\n"
|
|
"add r1, %[destination], %[destination_stride]\n"
|
|
"add r2, r0, %[stride]\n"
|
|
"add r3, r1, %[destination_stride]\n"
|
|
"subs %[count], %[count], #5\n"
|
|
"beq 2f\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d12, d13, d14, d15}, [%[source]:64]!\n"
|
|
"vld1.32 {d16, d17, d18, d19}, [r0:64]!\n"
|
|
"vld1.32 {d20, d21, d22, d23}, [r2:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"pld [r0]\n"
|
|
"pld [r2]\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q3\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vadd.i32 q9, q9, q4\n"
|
|
"vadd.i32 q10, q10, q5\n"
|
|
"vadd.i32 q11, q11, q5\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vmul.i32 q9, q9, q0\n"
|
|
"vmul.i32 q10, q10, q0\n"
|
|
"vmul.i32 q11, q11, q0\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vadd.i32 q9, q9, q1\n"
|
|
"vadd.i32 q10, q10, q1\n"
|
|
"vadd.i32 q11, q11, q1\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vshl.s32 q9, q9, q2\n"
|
|
"vshl.s32 q10, q10, q2\n"
|
|
"vshl.s32 q11, q11, q2\n"
|
|
"vqmovn.s32 d24, q6\n"
|
|
"vqmovn.s32 d25, q7\n"
|
|
"vqmovn.s32 d26, q8\n"
|
|
"vqmovn.s32 d27, q9\n"
|
|
"vqmovn.s32 d28, q10\n"
|
|
"vqmovn.s32 d29, q11\n"
|
|
"vqmovun.s16 d24, q12\n"
|
|
"vqmovun.s16 d26, q13\n"
|
|
"vqmovun.s16 d28, q14\n"
|
|
"vst1.8 {d24}, [%[destination]:64]!\n"
|
|
"vst1.8 {d26}, [r1:64]!\n"
|
|
"vst1.8 {d28}, [r3:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
"2:"
|
|
"vld1.32 {d12, d13}, [%[source]:64]!\n"
|
|
"vld1.32 {d16, d17}, [r0:64]!\n"
|
|
"vld1.32 {d20, d21}, [r2:64]!\n"
|
|
"vld1.32 {d14[0]}, [%[source]]\n"
|
|
"vld1.32 {d18[0]}, [r0]\n"
|
|
"vld1.32 {d22[0]}, [r2]\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q3\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vadd.i32 q9, q9, q4\n"
|
|
"vadd.i32 q10, q10, q5\n"
|
|
"vadd.i32 q11, q11, q5\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vmul.i32 q9, q9, q0\n"
|
|
"vmul.i32 q10, q10, q0\n"
|
|
"vmul.i32 q11, q11, q0\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vadd.i32 q9, q9, q1\n"
|
|
"vadd.i32 q10, q10, q1\n"
|
|
"vadd.i32 q11, q11, q1\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vshl.s32 q9, q9, q2\n"
|
|
"vshl.s32 q10, q10, q2\n"
|
|
"vshl.s32 q11, q11, q2\n"
|
|
"vqmovn.s32 d24, q6\n"
|
|
"vqmovn.s32 d25, q7\n"
|
|
"vqmovn.s32 d26, q8\n"
|
|
"vqmovn.s32 d27, q9\n"
|
|
"vqmovn.s32 d28, q10\n"
|
|
"vqmovn.s32 d29, q11\n"
|
|
"vqmovun.s16 d24, q12\n"
|
|
"vqmovun.s16 d26, q13\n"
|
|
"vqmovun.s16 d28, q14\n"
|
|
"vst1.32 {d24[0]}, [%[destination]]!\n"
|
|
"vst1.32 {d26[0]}, [r1]!\n"
|
|
"vst1.32 {d28[0]}, [r3]!\n"
|
|
"vst1.8 {d24[4]}, [%[destination]]\n"
|
|
"vst1.8 {d26[4]}, [r1]\n"
|
|
"vst1.8 {d28[4]}, [r3]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "r0", "r1", "r2", "r3", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
|
|
"d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
|
|
"d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
|
|
"d28", "d29", "cc", "memory");
|
|
}
|
|
|
|
void qnt_3x8_6_aligned(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination,
|
|
std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t rounding_offset, std::int32_t shift) {
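  // Three-row variant for counts of the form 8*n + 6: the tail writes the
  // six leftover bytes per row as a 32-bit store plus a 16-bit store.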
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d10[], d11[]}, [%[offsets]:32]!\n"
|
|
"add r0, %[source], %[stride]\n"
|
|
"add r1, %[destination], %[destination_stride]\n"
|
|
"add r2, r0, %[stride]\n"
|
|
"add r3, r1, %[destination_stride]\n"
|
|
"subs %[count], %[count], #6\n"
|
|
"beq 2f\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d12, d13, d14, d15}, [%[source]:64]!\n"
|
|
"vld1.32 {d16, d17, d18, d19}, [r0:64]!\n"
|
|
"vld1.32 {d20, d21, d22, d23}, [r2:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"pld [r0]\n"
|
|
"pld [r2]\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q3\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vadd.i32 q9, q9, q4\n"
|
|
"vadd.i32 q10, q10, q5\n"
|
|
"vadd.i32 q11, q11, q5\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vmul.i32 q9, q9, q0\n"
|
|
"vmul.i32 q10, q10, q0\n"
|
|
"vmul.i32 q11, q11, q0\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vadd.i32 q9, q9, q1\n"
|
|
"vadd.i32 q10, q10, q1\n"
|
|
"vadd.i32 q11, q11, q1\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vshl.s32 q9, q9, q2\n"
|
|
"vshl.s32 q10, q10, q2\n"
|
|
"vshl.s32 q11, q11, q2\n"
|
|
"vqmovn.s32 d24, q6\n"
|
|
"vqmovn.s32 d25, q7\n"
|
|
"vqmovn.s32 d26, q8\n"
|
|
"vqmovn.s32 d27, q9\n"
|
|
"vqmovn.s32 d28, q10\n"
|
|
"vqmovn.s32 d29, q11\n"
|
|
"vqmovun.s16 d24, q12\n"
|
|
"vqmovun.s16 d26, q13\n"
|
|
"vqmovun.s16 d28, q14\n"
|
|
"vst1.8 {d24}, [%[destination]:64]!\n"
|
|
"vst1.8 {d26}, [r1:64]!\n"
|
|
"vst1.8 {d28}, [r3:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
"2:"
|
|
"vld1.32 {d12, d13, d14}, [%[source]:64]\n"
|
|
"vld1.32 {d16, d17, d18}, [r0:64]\n"
|
|
"vld1.32 {d20, d21, d22}, [r2:64]\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q3\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vadd.i32 q9, q9, q4\n"
|
|
"vadd.i32 q10, q10, q5\n"
|
|
"vadd.i32 q11, q11, q5\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vmul.i32 q9, q9, q0\n"
|
|
"vmul.i32 q10, q10, q0\n"
|
|
"vmul.i32 q11, q11, q0\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vadd.i32 q9, q9, q1\n"
|
|
"vadd.i32 q10, q10, q1\n"
|
|
"vadd.i32 q11, q11, q1\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vshl.s32 q9, q9, q2\n"
|
|
"vshl.s32 q10, q10, q2\n"
|
|
"vshl.s32 q11, q11, q2\n"
|
|
"vqmovn.s32 d24, q6\n"
|
|
"vqmovn.s32 d25, q7\n"
|
|
"vqmovn.s32 d26, q8\n"
|
|
"vqmovn.s32 d27, q9\n"
|
|
"vqmovn.s32 d28, q10\n"
|
|
"vqmovn.s32 d29, q11\n"
|
|
"vqmovun.s16 d24, q12\n"
|
|
"vqmovun.s16 d26, q13\n"
|
|
"vqmovun.s16 d28, q14\n"
|
|
"vst1.32 {d24[0]}, [%[destination]]!\n"
|
|
"vst1.32 {d26[0]}, [r1]!\n"
|
|
"vst1.32 {d28[0]}, [r3]!\n"
|
|
"vst1.16 {d24[2]}, [%[destination]]\n"
|
|
"vst1.16 {d26[2]}, [r1]\n"
|
|
"vst1.16 {d28[2]}, [r3]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "r0", "r1", "r2", "r3", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
|
|
"d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
|
|
"d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
|
|
"d28", "d29", "cc", "memory");
|
|
}
|
|
|
|
void qnt_3x8_7_aligned(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination,
|
|
std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t rounding_offset, std::int32_t shift) {
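  // Three-row variant for counts of the form 8*n + 7: the tail writes the
  // seven leftover bytes per row as 32-bit, 16-bit and 8-bit stores.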
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d10[], d11[]}, [%[offsets]:32]!\n"
|
|
"add r0, %[source], %[stride]\n"
|
|
"add r1, %[destination], %[destination_stride]\n"
|
|
"add r2, r0, %[stride]\n"
|
|
"add r3, r1, %[destination_stride]\n"
|
|
"subs %[count], %[count], #7\n"
|
|
"beq 2f\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d12, d13, d14, d15}, [%[source]:64]!\n"
|
|
"vld1.32 {d16, d17, d18, d19}, [r0:64]!\n"
|
|
"vld1.32 {d20, d21, d22, d23}, [r2:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"pld [r0]\n"
|
|
"pld [r2]\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q3\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vadd.i32 q9, q9, q4\n"
|
|
"vadd.i32 q10, q10, q5\n"
|
|
"vadd.i32 q11, q11, q5\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vmul.i32 q9, q9, q0\n"
|
|
"vmul.i32 q10, q10, q0\n"
|
|
"vmul.i32 q11, q11, q0\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vadd.i32 q9, q9, q1\n"
|
|
"vadd.i32 q10, q10, q1\n"
|
|
"vadd.i32 q11, q11, q1\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vshl.s32 q9, q9, q2\n"
|
|
"vshl.s32 q10, q10, q2\n"
|
|
"vshl.s32 q11, q11, q2\n"
|
|
"vqmovn.s32 d24, q6\n"
|
|
"vqmovn.s32 d25, q7\n"
|
|
"vqmovn.s32 d26, q8\n"
|
|
"vqmovn.s32 d27, q9\n"
|
|
"vqmovn.s32 d28, q10\n"
|
|
"vqmovn.s32 d29, q11\n"
|
|
"vqmovun.s16 d24, q12\n"
|
|
"vqmovun.s16 d26, q13\n"
|
|
"vqmovun.s16 d28, q14\n"
|
|
"vst1.8 {d24}, [%[destination]:64]!\n"
|
|
"vst1.8 {d26}, [r1:64]!\n"
|
|
"vst1.8 {d28}, [r3:64]!\n"
|
|
|
|
"bne 1b\n"
|
|
"2:"
|
|
"vld1.32 {d12, d13, d14}, [%[source]:64]!\n"
|
|
"vld1.32 {d16, d17, d18}, [r0:64]!\n"
|
|
"vld1.32 {d20, d21, d22}, [r2:64]!\n"
|
|
"vld1.32 {d15[0]}, [%[source]]\n"
|
|
"vld1.32 {d19[0]}, [r0]\n"
|
|
"vld1.32 {d23[0]}, [r2]\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q3\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vadd.i32 q9, q9, q4\n"
|
|
"vadd.i32 q10, q10, q5\n"
|
|
"vadd.i32 q11, q11, q5\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vmul.i32 q9, q9, q0\n"
|
|
"vmul.i32 q10, q10, q0\n"
|
|
"vmul.i32 q11, q11, q0\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vadd.i32 q9, q9, q1\n"
|
|
"vadd.i32 q10, q10, q1\n"
|
|
"vadd.i32 q11, q11, q1\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vshl.s32 q9, q9, q2\n"
|
|
"vshl.s32 q10, q10, q2\n"
|
|
"vshl.s32 q11, q11, q2\n"
|
|
"vqmovn.s32 d24, q6\n"
|
|
"vqmovn.s32 d25, q7\n"
|
|
"vqmovn.s32 d26, q8\n"
|
|
"vqmovn.s32 d27, q9\n"
|
|
"vqmovn.s32 d28, q10\n"
|
|
"vqmovn.s32 d29, q11\n"
|
|
"vqmovun.s16 d24, q12\n"
|
|
"vqmovun.s16 d26, q13\n"
|
|
"vqmovun.s16 d28, q14\n"
|
|
"vst1.32 {d24[0]}, [%[destination]]!\n"
|
|
"vst1.32 {d26[0]}, [r1]!\n"
|
|
"vst1.32 {d28[0]}, [r3]!\n"
|
|
"vst1.16 {d24[2]}, [%[destination]]!\n"
|
|
"vst1.16 {d26[2]}, [r1]!\n"
|
|
"vst1.16 {d28[2]}, [r3]!\n"
|
|
"vst1.8 {d24[6]}, [%[destination]]!\n"
|
|
"vst1.8 {d26[6]}, [r1]!\n"
|
|
"vst1.8 {d28[6]}, [r3]!\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "r0", "r1", "r2", "r3", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
|
|
"d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
|
|
"d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
|
|
"d28", "d29", "cc", "memory");
|
|
}
|
|
|
|
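// Single-row quantization kernel for counts that are a multiple of 8; the
// destination stores carry no alignment hint, unlike the *_aligned kernels
// above.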
void qnt_1x8(const std::int32_t* source, std::int32_t count,
             std::int32_t stride, const std::int32_t* offsets,
             std::uint8_t* destination, std::int32_t destination_stride,
             std::int32_t multiplicative_offset, std::int32_t rounding_offset,
             std::int32_t shift) {
  asm volatile(
      "vdup.32 q0, %[multiplicative_offset]\n"
      "vdup.32 q1, %[rounding_offset]\n"
      "vdup.32 q2, %[shift]\n"
      "vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"

      "1:"
      "subs %[count], %[count], #8\n"
      "vld1.32 {d8, d9, d10, d11}, [%[source]:64]!\n"
      "pld [%[source]]\n"
      "vadd.i32 q4, q4, q3\n"
      "vadd.i32 q5, q5, q3\n"
      "vmul.i32 q4, q4, q0\n"
      "vmul.i32 q5, q5, q0\n"
      "vadd.i32 q4, q4, q1\n"
      "vadd.i32 q5, q5, q1\n"
      "vshl.s32 q4, q4, q2\n"
      "vshl.s32 q5, q5, q2\n"
      "vqmovn.s32 d12, q4\n"
      "vqmovn.s32 d13, q5\n"
      "vqmovun.s16 d12, q6\n"
      "vst1.8 {d12}, [%[destination]]!\n"

      "bne 1b\n"
      : [count] "+r"(count),
        [multiplicative_offset] "+r"(multiplicative_offset),
        [stride] "+r"(stride), [shift] "+r"(shift),
        [destination] "+r"(destination), [offsets] "+r"(offsets),
        [source] "+r"(source), [destination_stride] "+r"(destination_stride),
        [rounding_offset] "+r"(rounding_offset)
      :
      : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
        "d11", "d12", "d13", "cc", "memory");
}
|
|
|
|
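// Single-row variant for counts of the form 8*n + 1: the main loop handles
// eight values per iteration and the "2:" tail quantizes the one leftover
// value, storing it with an 8-bit lane store.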
void qnt_1x8_1(const std::int32_t* source, std::int32_t count,
               std::int32_t stride, const std::int32_t* offsets,
               std::uint8_t* destination, std::int32_t destination_stride,
               std::int32_t multiplicative_offset, std::int32_t rounding_offset,
               std::int32_t shift) {
  asm volatile(
      "vdup.32 q0, %[multiplicative_offset]\n"
      "vdup.32 q1, %[rounding_offset]\n"
      "vdup.32 q2, %[shift]\n"
      "vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
      "subs %[count], %[count], #1\n"
      "beq 2f\n"

      "1:"
      "subs %[count], %[count], #8\n"
      "vld1.32 {d8, d9, d10, d11}, [%[source]:64]!\n"
      "pld [%[source]]\n"
      "vadd.i32 q4, q4, q3\n"
      "vadd.i32 q5, q5, q3\n"
      "vmul.i32 q4, q4, q0\n"
      "vmul.i32 q5, q5, q0\n"
      "vadd.i32 q4, q4, q1\n"
      "vadd.i32 q5, q5, q1\n"
      "vshl.s32 q4, q4, q2\n"
      "vshl.s32 q5, q5, q2\n"
      "vqmovn.s32 d12, q4\n"
      "vqmovn.s32 d13, q5\n"
      "vqmovun.s16 d12, q6\n"
      "vst1.8 {d12}, [%[destination]]!\n"

      "bne 1b\n"
      "2:"
      "vld1.32 {d8[0]}, [%[source]]\n"
      "vadd.i32 q4, q4, q3\n"
      "vmul.i32 q4, q4, q0\n"
      "vadd.i32 q4, q4, q1\n"
      "vshl.s32 q4, q4, q2\n"
      "vqmovn.s32 d12, q4\n"
      "vqmovun.s16 d12, q6\n"
      "vst1.8 {d12[0]}, [%[destination]]\n"
      : [count] "+r"(count),
        [multiplicative_offset] "+r"(multiplicative_offset),
        [stride] "+r"(stride), [shift] "+r"(shift),
        [destination] "+r"(destination), [offsets] "+r"(offsets),
        [source] "+r"(source), [destination_stride] "+r"(destination_stride),
        [rounding_offset] "+r"(rounding_offset)
      :
      : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
        "d11", "d12", "d13", "cc", "memory");
}
|
|
|
|
void qnt_1x8_2(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination, std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
|
|
std::int32_t shift) {
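  // Single-row variant for counts of the form 8*n + 2: the two leftover
  // results are written back with a single 16-bit store.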
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"subs %[count], %[count], #2\n"
|
|
"beq 2f\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d8, d9, d10, d11}, [%[source]:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"vadd.i32 q4, q4, q3\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vmul.i32 q4, q4, q0\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vshl.s32 q4, q4, q2\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vqmovn.s32 d12, q4\n"
|
|
"vqmovn.s32 d13, q5\n"
|
|
"vqmovun.s16 d12, q6\n"
|
|
"vst1.8 {d12}, [%[destination]]!\n"
|
|
|
|
"bne 1b\n"
|
|
"2:"
|
|
"vld1.32 {d8}, [%[source]:64]\n"
|
|
"vadd.i32 q4, q4, q3\n"
|
|
"vmul.i32 q4, q4, q0\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vshl.s32 q4, q4, q2\n"
|
|
"vqmovn.s32 d12, q4\n"
|
|
"vqmovun.s16 d12, q6\n"
|
|
"vst1.16 {d12[0]}, [%[destination]]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
|
|
"d11", "d12", "d13", "cc", "memory");
|
|
}
|
|
|
|
void qnt_1x8_3(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination, std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
|
|
std::int32_t shift) {
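  // Single-row variant for counts of the form 8*n + 3: the tail writes a
  // 16-bit store followed by an 8-bit lane store for the third byte.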
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"subs %[count], %[count], #3\n"
|
|
"beq 2f\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d8, d9, d10, d11}, [%[source]:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"vadd.i32 q4, q4, q3\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vmul.i32 q4, q4, q0\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vshl.s32 q4, q4, q2\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vqmovn.s32 d12, q4\n"
|
|
"vqmovn.s32 d13, q5\n"
|
|
"vqmovun.s16 d12, q6\n"
|
|
"vst1.8 {d12}, [%[destination]]!\n"
|
|
|
|
"bne 1b\n"
|
|
"2:"
|
|
"vld1.32 {d8}, [%[source]:64]!\n"
|
|
"vld1.32 {d9[0]}, [%[source]]\n"
|
|
"vadd.i32 q4, q4, q3\n"
|
|
"vmul.i32 q4, q4, q0\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vshl.s32 q4, q4, q2\n"
|
|
"vqmovn.s32 d12, q4\n"
|
|
"vqmovun.s16 d12, q6\n"
|
|
"vst1.16 {d12[0]}, [%[destination]]!\n"
|
|
"vst1.8 {d12[2]}, [%[destination]]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
|
|
"d11", "d12", "d13", "cc", "memory");
|
|
}
|
|
|
|
void qnt_1x8_4(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination, std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
|
|
std::int32_t shift) {
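  // Single-row variant for counts of the form 8*n + 4: the four leftover
  // results are written back with a single 32-bit store.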
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"subs %[count], %[count], #4\n"
|
|
"beq 2f\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d8, d9, d10, d11}, [%[source]:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"vadd.i32 q4, q4, q3\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vmul.i32 q4, q4, q0\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vshl.s32 q4, q4, q2\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vqmovn.s32 d12, q4\n"
|
|
"vqmovn.s32 d13, q5\n"
|
|
"vqmovun.s16 d12, q6\n"
|
|
"vst1.8 {d12}, [%[destination]]!\n"
|
|
|
|
"bne 1b\n"
|
|
"2:"
|
|
"vld1.32 {d8, d9}, [%[source]:64]\n"
|
|
"vadd.i32 q4, q4, q3\n"
|
|
"vmul.i32 q4, q4, q0\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vshl.s32 q4, q4, q2\n"
|
|
"vqmovn.s32 d12, q4\n"
|
|
"vqmovun.s16 d12, q6\n"
|
|
"vst1.32 {d12[0]}, [%[destination]]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
|
|
"d11", "d12", "d13", "cc", "memory");
|
|
}
|
|
|
|
void qnt_1x8_5(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination, std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
|
|
std::int32_t shift) {
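  // Single-row variant for counts of the form 8*n + 5: the tail writes a
  // 32-bit store followed by an 8-bit lane store.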
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"subs %[count], %[count], #5\n"
|
|
"beq 2f\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d8, d9, d10, d11}, [%[source]:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"vadd.i32 q4, q4, q3\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vmul.i32 q4, q4, q0\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vshl.s32 q4, q4, q2\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vqmovn.s32 d12, q4\n"
|
|
"vqmovn.s32 d13, q5\n"
|
|
"vqmovun.s16 d12, q6\n"
|
|
"vst1.8 {d12}, [%[destination]]!\n"
|
|
|
|
"bne 1b\n"
|
|
"2:"
|
|
"vld1.32 {d8, d9}, [%[source]:64]!\n"
|
|
"vld1.32 {d10[0]}, [%[source]]\n"
|
|
"vadd.i32 q4, q4, q3\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vmul.i32 q4, q4, q0\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vshl.s32 q4, q4, q2\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vqmovn.s32 d12, q4\n"
|
|
"vqmovn.s32 d13, q5\n"
|
|
"vqmovun.s16 d12, q6\n"
|
|
"vst1.32 {d12[0]}, [%[destination]]!\n"
|
|
"vst1.8 {d12[4]}, [%[destination]]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
|
|
"d11", "d12", "d13", "cc", "memory");
|
|
}
|
|
|
|
void qnt_1x8_6(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination, std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
|
|
std::int32_t shift) {
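  // Single-row variant for counts of the form 8*n + 6: the tail writes a
  // 32-bit store followed by a 16-bit store.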
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"subs %[count], %[count], #6\n"
|
|
"beq 2f\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d8, d9, d10, d11}, [%[source]:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"vadd.i32 q4, q4, q3\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vmul.i32 q4, q4, q0\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vshl.s32 q4, q4, q2\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vqmovn.s32 d12, q4\n"
|
|
"vqmovn.s32 d13, q5\n"
|
|
"vqmovun.s16 d12, q6\n"
|
|
"vst1.8 {d12}, [%[destination]]!\n"
|
|
|
|
"bne 1b\n"
|
|
"2:"
|
|
"vld1.32 {d8, d9, d10}, [%[source]:64]\n"
|
|
"vadd.i32 q4, q4, q3\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vmul.i32 q4, q4, q0\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vshl.s32 q4, q4, q2\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vqmovn.s32 d12, q4\n"
|
|
"vqmovn.s32 d13, q5\n"
|
|
"vqmovun.s16 d12, q6\n"
|
|
"vst1.32 {d12[0]}, [%[destination]]!\n"
|
|
"vst1.16 {d12[2]}, [%[destination]]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
|
|
"d11", "d12", "d13", "cc", "memory");
|
|
}
|
|
|
|
void qnt_1x8_7(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination, std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
|
|
std::int32_t shift) {
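  // Single-row variant for counts of the form 8*n + 7: the tail writes
  // 32-bit, 16-bit and 8-bit stores.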
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"subs %[count], %[count], #7\n"
|
|
"beq 2f\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d8, d9, d10, d11}, [%[source]:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"vadd.i32 q4, q4, q3\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vmul.i32 q4, q4, q0\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vshl.s32 q4, q4, q2\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vqmovn.s32 d12, q4\n"
|
|
"vqmovn.s32 d13, q5\n"
|
|
"vqmovun.s16 d12, q6\n"
|
|
"vst1.8 {d12}, [%[destination]]!\n"
|
|
|
|
"bne 1b\n"
|
|
"2:"
|
|
"vld1.32 {d8, d9, d10}, [%[source]:64]!\n"
|
|
"vld1.32 {d11[0]}, [%[source]]\n"
|
|
"vadd.i32 q4, q4, q3\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vmul.i32 q4, q4, q0\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vadd.i32 q4, q4, q1\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vshl.s32 q4, q4, q2\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vqmovn.s32 d12, q4\n"
|
|
"vqmovn.s32 d13, q5\n"
|
|
"vqmovun.s16 d12, q6\n"
|
|
"vst1.32 {d12[0]}, [%[destination]]!\n"
|
|
"vst1.16 {d12[2]}, [%[destination]]!\n"
|
|
"vst1.8 {d12[6]}, [%[destination]]!\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
|
|
"d11", "d12", "d13", "cc", "memory");
|
|
}
|
|
|
|
void qnt_2x8(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination, std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
|
|
std::int32_t shift) {
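  // Two-row quantization kernel for counts that are a multiple of 8; r0 and
  // r1 hold the second source and destination row pointers.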
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
|
|
"add r0, %[source], %[stride]\n"
|
|
"add r1, %[destination], %[destination_stride]\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d10, d11, d12, d13}, [%[source]:64]!\n"
|
|
"vld1.32 {d14, d15, d16, d17}, [r0:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"pld [r0]\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q4\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vqmovn.s32 d18, q5\n"
|
|
"vqmovn.s32 d19, q6\n"
|
|
"vqmovn.s32 d20, q7\n"
|
|
"vqmovn.s32 d21, q8\n"
|
|
"vqmovun.s16 d18, q9\n"
|
|
"vqmovun.s16 d20, q10\n"
|
|
"vst1.8 {d18}, [%[destination]]!\n"
|
|
"vst1.8 {d20}, [r1]!\n"
|
|
|
|
"bne 1b\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
|
|
"d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19",
|
|
"d20", "d21", "cc", "memory");
|
|
}
|
|
|
|
void qnt_2x8_1(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination, std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
|
|
std::int32_t shift) {
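  // Two-row variant for counts of the form 8*n + 1: the "2:" tail quantizes
  // the single leftover value of each row and stores it with an 8-bit lane
  // store.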
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
|
|
"add r0, %[source], %[stride]\n"
|
|
"add r1, %[destination], %[destination_stride]\n"
|
|
"subs %[count], %[count], #1\n"
|
|
"beq 2f\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d10, d11, d12, d13}, [%[source]:64]!\n"
|
|
"vld1.32 {d14, d15, d16, d17}, [r0:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"pld [r0]\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q4\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vqmovn.s32 d18, q5\n"
|
|
"vqmovn.s32 d19, q6\n"
|
|
"vqmovn.s32 d20, q7\n"
|
|
"vqmovn.s32 d21, q8\n"
|
|
"vqmovun.s16 d18, q9\n"
|
|
"vqmovun.s16 d20, q10\n"
|
|
"vst1.8 {d18}, [%[destination]]!\n"
|
|
"vst1.8 {d20}, [r1]!\n"
|
|
|
|
"bne 1b\n"
|
|
"2:"
|
|
"vld1.32 {d10[0]}, [%[source]]\n"
|
|
"vld1.32 {d14[0]}, [r0]\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vadd.i32 q7, q7, q4\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vqmovn.s32 d18, q5\n"
|
|
"vqmovn.s32 d20, q7\n"
|
|
"vqmovun.s16 d18, q9\n"
|
|
"vqmovun.s16 d20, q10\n"
|
|
"vst1.8 {d18[0]}, [%[destination]]\n"
|
|
"vst1.8 {d20[0]}, [r1]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
|
|
"d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19",
|
|
"d20", "d21", "cc", "memory");
|
|
}
|
|
|
|
void qnt_2x8_2(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination, std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
|
|
std::int32_t shift) {
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
|
|
"add r0, %[source], %[stride]\n"
|
|
"add r1, %[destination], %[destination_stride]\n"
|
|
"subs %[count], %[count], #2\n"
|
|
"beq 2f\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d10, d11, d12, d13}, [%[source]:64]!\n"
|
|
"vld1.32 {d14, d15, d16, d17}, [r0:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"pld [r0]\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q4\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vqmovn.s32 d18, q5\n"
|
|
"vqmovn.s32 d19, q6\n"
|
|
"vqmovn.s32 d20, q7\n"
|
|
"vqmovn.s32 d21, q8\n"
|
|
"vqmovun.s16 d18, q9\n"
|
|
"vqmovun.s16 d20, q10\n"
|
|
"vst1.8 {d18}, [%[destination]]!\n"
|
|
"vst1.8 {d20}, [r1]!\n"
|
|
|
|
"bne 1b\n"
|
|
"2:"
|
|
"vld1.32 {d10}, [%[source]:64]\n"
|
|
"vld1.32 {d14}, [r0:64]\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vadd.i32 q7, q7, q4\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vqmovn.s32 d18, q5\n"
|
|
"vqmovn.s32 d20, q7\n"
|
|
"vqmovun.s16 d18, q9\n"
|
|
"vqmovun.s16 d20, q10\n"
|
|
"vst1.16 {d18[0]}, [%[destination]]\n"
|
|
"vst1.16 {d20[0]}, [r1]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
|
|
"d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19",
|
|
"d20", "d21", "cc", "memory");
|
|
}
|
|
|
|
void qnt_2x8_3(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination, std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
|
|
std::int32_t shift) {
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
|
|
"add r0, %[source], %[stride]\n"
|
|
"add r1, %[destination], %[destination_stride]\n"
|
|
"subs %[count], %[count], #3\n"
|
|
"beq 2f\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d10, d11, d12, d13}, [%[source]:64]!\n"
|
|
"vld1.32 {d14, d15, d16, d17}, [r0:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"pld [r0]\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q4\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vqmovn.s32 d18, q5\n"
|
|
"vqmovn.s32 d19, q6\n"
|
|
"vqmovn.s32 d20, q7\n"
|
|
"vqmovn.s32 d21, q8\n"
|
|
"vqmovun.s16 d18, q9\n"
|
|
"vqmovun.s16 d20, q10\n"
|
|
"vst1.8 {d18}, [%[destination]]!\n"
|
|
"vst1.8 {d20}, [r1]!\n"
|
|
|
|
"bne 1b\n"
|
|
"2:"
|
|
"vld1.32 {d10}, [%[source]:64]!\n"
|
|
"vld1.32 {d14}, [r0:64]!\n"
|
|
"vld1.32 {d11[0]}, [%[source]]\n"
|
|
"vld1.32 {d15[0]}, [r0]\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vadd.i32 q7, q7, q4\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vqmovn.s32 d18, q5\n"
|
|
"vqmovn.s32 d20, q7\n"
|
|
"vqmovun.s16 d18, q9\n"
|
|
"vqmovun.s16 d20, q10\n"
|
|
"vst1.16 {d18[0]}, [%[destination]]!\n"
|
|
"vst1.16 {d20[0]}, [r1]!\n"
|
|
"vst1.8 {d18[2]}, [%[destination]]\n"
|
|
"vst1.8 {d20[2]}, [r1]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
|
|
"d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19",
|
|
"d20", "d21", "cc", "memory");
|
|
}
|
|
|
|
void qnt_2x8_4(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination, std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
|
|
std::int32_t shift) {
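  // Two-row variant for counts of the form 8*n + 4: the four leftover results
  // of each row are written back with a single 32-bit store.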
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
|
|
"add r0, %[source], %[stride]\n"
|
|
"add r1, %[destination], %[destination_stride]\n"
|
|
"subs %[count], %[count], #4\n"
|
|
"beq 2f\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d10, d11, d12, d13}, [%[source]:64]!\n"
|
|
"vld1.32 {d14, d15, d16, d17}, [r0:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"pld [r0]\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q4\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vqmovn.s32 d18, q5\n"
|
|
"vqmovn.s32 d19, q6\n"
|
|
"vqmovn.s32 d20, q7\n"
|
|
"vqmovn.s32 d21, q8\n"
|
|
"vqmovun.s16 d18, q9\n"
|
|
"vqmovun.s16 d20, q10\n"
|
|
"vst1.8 {d18}, [%[destination]]!\n"
|
|
"vst1.8 {d20}, [r1]!\n"
|
|
|
|
"bne 1b\n"
|
|
"2:"
|
|
"vld1.32 {d10, d11}, [%[source]:64]\n"
|
|
"vld1.32 {d14, d15}, [r0:64]\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vadd.i32 q7, q7, q4\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vqmovn.s32 d18, q5\n"
|
|
"vqmovn.s32 d20, q7\n"
|
|
"vqmovun.s16 d18, q9\n"
|
|
"vqmovun.s16 d20, q10\n"
|
|
"vst1.32 {d18[0]}, [%[destination]]\n"
|
|
"vst1.32 {d20[0]}, [r1]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
|
|
"d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19",
|
|
"d20", "d21", "cc", "memory");
|
|
}
|
|
|
|
void qnt_2x8_5(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination, std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
|
|
std::int32_t shift) {
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
|
|
"add r0, %[source], %[stride]\n"
|
|
"add r1, %[destination], %[destination_stride]\n"
|
|
"subs %[count], %[count], #5\n"
|
|
"beq 2f\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d10, d11, d12, d13}, [%[source]:64]!\n"
|
|
"vld1.32 {d14, d15, d16, d17}, [r0:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"pld [r0]\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q4\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vqmovn.s32 d18, q5\n"
|
|
"vqmovn.s32 d19, q6\n"
|
|
"vqmovn.s32 d20, q7\n"
|
|
"vqmovn.s32 d21, q8\n"
|
|
"vqmovun.s16 d18, q9\n"
|
|
"vqmovun.s16 d20, q10\n"
|
|
"vst1.8 {d18}, [%[destination]]!\n"
|
|
"vst1.8 {d20}, [r1]!\n"
|
|
|
|
"bne 1b\n"
|
|
"2:"
|
|
"vld1.32 {d10, d11}, [%[source]:64]!\n"
|
|
"vld1.32 {d14, d15}, [r0:64]!\n"
|
|
"vld1.32 {d12[0]}, [%[source]]\n"
|
|
"vld1.32 {d16[0]}, [r0]\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q4\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vqmovn.s32 d18, q5\n"
|
|
"vqmovn.s32 d19, q6\n"
|
|
"vqmovn.s32 d20, q7\n"
|
|
"vqmovn.s32 d21, q8\n"
|
|
"vqmovun.s16 d18, q9\n"
|
|
"vqmovun.s16 d20, q10\n"
|
|
"vst1.32 {d18[0]}, [%[destination]]!\n"
|
|
"vst1.32 {d20[0]}, [r1]!\n"
|
|
"vst1.8 {d18[4]}, [%[destination]]\n"
|
|
"vst1.8 {d20[4]}, [r1]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
|
|
"d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19",
|
|
"d20", "d21", "cc", "memory");
|
|
}
|
|
|
|
void qnt_2x8_6(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination, std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
|
|
std::int32_t shift) {
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
|
|
"add r0, %[source], %[stride]\n"
|
|
"add r1, %[destination], %[destination_stride]\n"
|
|
"subs %[count], %[count], #6\n"
|
|
"beq 2f\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d10, d11, d12, d13}, [%[source]:64]!\n"
|
|
"vld1.32 {d14, d15, d16, d17}, [r0:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"pld [r0]\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q4\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vqmovn.s32 d18, q5\n"
|
|
"vqmovn.s32 d19, q6\n"
|
|
"vqmovn.s32 d20, q7\n"
|
|
"vqmovn.s32 d21, q8\n"
|
|
"vqmovun.s16 d18, q9\n"
|
|
"vqmovun.s16 d20, q10\n"
|
|
"vst1.8 {d18}, [%[destination]]!\n"
|
|
"vst1.8 {d20}, [r1]!\n"
|
|
|
|
"bne 1b\n"
|
|
"2:"
|
|
"vld1.32 {d10, d11, d12}, [%[source]:64]\n"
|
|
"vld1.32 {d14, d15, d16}, [r0:64]\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q4\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vqmovn.s32 d18, q5\n"
|
|
"vqmovn.s32 d19, q6\n"
|
|
"vqmovn.s32 d20, q7\n"
|
|
"vqmovn.s32 d21, q8\n"
|
|
"vqmovun.s16 d18, q9\n"
|
|
"vqmovun.s16 d20, q10\n"
|
|
"vst1.32 {d18[0]}, [%[destination]]!\n"
|
|
"vst1.32 {d20[0]}, [r1]!\n"
|
|
"vst1.16 {d18[2]}, [%[destination]]\n"
|
|
"vst1.16 {d20[2]}, [r1]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
|
|
"d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19",
|
|
"d20", "d21", "cc", "memory");
|
|
}
|
|
|
|
void qnt_2x8_7(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination, std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
|
|
std::int32_t shift) {
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
|
|
"add r0, %[source], %[stride]\n"
|
|
"add r1, %[destination], %[destination_stride]\n"
|
|
"subs %[count], %[count], #7\n"
|
|
"beq 2f\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d10, d11, d12, d13}, [%[source]:64]!\n"
|
|
"vld1.32 {d14, d15, d16, d17}, [r0:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"pld [r0]\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q4\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vqmovn.s32 d18, q5\n"
|
|
"vqmovn.s32 d19, q6\n"
|
|
"vqmovn.s32 d20, q7\n"
|
|
"vqmovn.s32 d21, q8\n"
|
|
"vqmovun.s16 d18, q9\n"
|
|
"vqmovun.s16 d20, q10\n"
|
|
"vst1.8 {d18}, [%[destination]]!\n"
|
|
"vst1.8 {d20}, [r1]!\n"
|
|
|
|
"bne 1b\n"
|
|
"2:"
|
|
"vld1.32 {d10, d11, d12}, [%[source]:64]!\n"
|
|
"vld1.32 {d14, d15, d16}, [r0:64]!\n"
|
|
"vld1.32 {d13[0]}, [%[source]]\n"
|
|
"vld1.32 {d17[0]}, [r0]\n"
|
|
"vadd.i32 q5, q5, q3\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q4\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vmul.i32 q5, q5, q0\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vadd.i32 q5, q5, q1\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vshl.s32 q5, q5, q2\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vqmovn.s32 d18, q5\n"
|
|
"vqmovn.s32 d19, q6\n"
|
|
"vqmovn.s32 d20, q7\n"
|
|
"vqmovn.s32 d21, q8\n"
|
|
"vqmovun.s16 d18, q9\n"
|
|
"vqmovun.s16 d20, q10\n"
|
|
"vst1.32 {d18[0]}, [%[destination]]!\n"
|
|
"vst1.32 {d20[0]}, [r1]!\n"
|
|
"vst1.16 {d18[2]}, [%[destination]]!\n"
|
|
"vst1.16 {d20[2]}, [r1]!\n"
|
|
"vst1.8 {d18[6]}, [%[destination]]!\n"
|
|
"vst1.8 {d20[6]}, [r1]!\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "r0", "r1", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
|
|
"d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19",
|
|
"d20", "d21", "cc", "memory");
|
|
}
|
|
|
|
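// qnt_3x8 and the qnt_3x8_N variants below quantize three rows per call:
// r0/r2 walk the second and third source rows (offset by `stride`), r1/r3 the
// second and third destination rows (offset by `destination_stride`), while
// the named operands walk the first row.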
void qnt_3x8(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination, std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
|
|
std::int32_t shift) {
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d10[], d11[]}, [%[offsets]:32]!\n"
|
|
"add r0, %[source], %[stride]\n"
|
|
"add r1, %[destination], %[destination_stride]\n"
|
|
"add r2, r0, %[stride]\n"
|
|
"add r3, r1, %[destination_stride]\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d12, d13, d14, d15}, [%[source]:64]!\n"
|
|
"vld1.32 {d16, d17, d18, d19}, [r0:64]!\n"
|
|
"vld1.32 {d20, d21, d22, d23}, [r2:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"pld [r0]\n"
|
|
"pld [r2]\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q3\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vadd.i32 q9, q9, q4\n"
|
|
"vadd.i32 q10, q10, q5\n"
|
|
"vadd.i32 q11, q11, q5\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vmul.i32 q9, q9, q0\n"
|
|
"vmul.i32 q10, q10, q0\n"
|
|
"vmul.i32 q11, q11, q0\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vadd.i32 q9, q9, q1\n"
|
|
"vadd.i32 q10, q10, q1\n"
|
|
"vadd.i32 q11, q11, q1\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vshl.s32 q9, q9, q2\n"
|
|
"vshl.s32 q10, q10, q2\n"
|
|
"vshl.s32 q11, q11, q2\n"
|
|
"vqmovn.s32 d24, q6\n"
|
|
"vqmovn.s32 d25, q7\n"
|
|
"vqmovn.s32 d26, q8\n"
|
|
"vqmovn.s32 d27, q9\n"
|
|
"vqmovn.s32 d28, q10\n"
|
|
"vqmovn.s32 d29, q11\n"
|
|
"vqmovun.s16 d24, q12\n"
|
|
"vqmovun.s16 d26, q13\n"
|
|
"vqmovun.s16 d28, q14\n"
|
|
"vst1.8 {d24}, [%[destination]]!\n"
|
|
"vst1.8 {d26}, [r1]!\n"
|
|
"vst1.8 {d28}, [r3]!\n"
|
|
|
|
"bne 1b\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "r0", "r1", "r2", "r3", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
|
|
"d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
|
|
"d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
|
|
"d28", "d29", "cc", "memory");
|
|
}
|
|
|
|
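// The numeric suffix on qnt_3x8_N is the number of leftover columns
// (count % 8): each variant subtracts N up front, runs the same 8-wide main
// loop, and handles the final N values per row in the tail after label "2:".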
void qnt_3x8_1(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination, std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
|
|
std::int32_t shift) {
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d10[], d11[]}, [%[offsets]:32]!\n"
|
|
"add r0, %[source], %[stride]\n"
|
|
"add r1, %[destination], %[destination_stride]\n"
|
|
"add r2, r0, %[stride]\n"
|
|
"add r3, r1, %[destination_stride]\n"
|
|
"subs %[count], %[count], #1\n"
|
|
"beq 2f\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d12, d13, d14, d15}, [%[source]:64]!\n"
|
|
"vld1.32 {d16, d17, d18, d19}, [r0:64]!\n"
|
|
"vld1.32 {d20, d21, d22, d23}, [r2:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"pld [r0]\n"
|
|
"pld [r2]\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q3\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vadd.i32 q9, q9, q4\n"
|
|
"vadd.i32 q10, q10, q5\n"
|
|
"vadd.i32 q11, q11, q5\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vmul.i32 q9, q9, q0\n"
|
|
"vmul.i32 q10, q10, q0\n"
|
|
"vmul.i32 q11, q11, q0\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vadd.i32 q9, q9, q1\n"
|
|
"vadd.i32 q10, q10, q1\n"
|
|
"vadd.i32 q11, q11, q1\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vshl.s32 q9, q9, q2\n"
|
|
"vshl.s32 q10, q10, q2\n"
|
|
"vshl.s32 q11, q11, q2\n"
|
|
"vqmovn.s32 d24, q6\n"
|
|
"vqmovn.s32 d25, q7\n"
|
|
"vqmovn.s32 d26, q8\n"
|
|
"vqmovn.s32 d27, q9\n"
|
|
"vqmovn.s32 d28, q10\n"
|
|
"vqmovn.s32 d29, q11\n"
|
|
"vqmovun.s16 d24, q12\n"
|
|
"vqmovun.s16 d26, q13\n"
|
|
"vqmovun.s16 d28, q14\n"
|
|
"vst1.8 {d24}, [%[destination]]!\n"
|
|
"vst1.8 {d26}, [r1]!\n"
|
|
"vst1.8 {d28}, [r3]!\n"
|
|
|
|
"bne 1b\n"
|
|
"2:"
|
|
"vld1.32 {d12[0]}, [%[source]]\n"
|
|
"vld1.32 {d16[0]}, [r0]\n"
|
|
"vld1.32 {d20[0]}, [r2]\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vadd.i32 q10, q10, q5\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vmul.i32 q10, q10, q0\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vadd.i32 q10, q10, q1\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vshl.s32 q10, q10, q2\n"
|
|
"vqmovn.s32 d24, q6\n"
|
|
"vqmovn.s32 d26, q8\n"
|
|
"vqmovn.s32 d28, q10\n"
|
|
"vqmovun.s16 d24, q12\n"
|
|
"vqmovun.s16 d26, q13\n"
|
|
"vqmovun.s16 d28, q14\n"
|
|
"vst1.8 {d24[0]}, [%[destination]]\n"
|
|
"vst1.8 {d26[0]}, [r1]\n"
|
|
"vst1.8 {d28[0]}, [r3]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "r0", "r1", "r2", "r3", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
|
|
"d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
|
|
"d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
|
|
"d28", "d29", "cc", "memory");
|
|
}
|
|
|
|
void qnt_3x8_2(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination, std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
|
|
std::int32_t shift) {
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d10[], d11[]}, [%[offsets]:32]!\n"
|
|
"add r0, %[source], %[stride]\n"
|
|
"add r1, %[destination], %[destination_stride]\n"
|
|
"add r2, r0, %[stride]\n"
|
|
"add r3, r1, %[destination_stride]\n"
|
|
"subs %[count], %[count], #2\n"
|
|
"beq 2f\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d12, d13, d14, d15}, [%[source]:64]!\n"
|
|
"vld1.32 {d16, d17, d18, d19}, [r0:64]!\n"
|
|
"vld1.32 {d20, d21, d22, d23}, [r2:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"pld [r0]\n"
|
|
"pld [r2]\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q3\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vadd.i32 q9, q9, q4\n"
|
|
"vadd.i32 q10, q10, q5\n"
|
|
"vadd.i32 q11, q11, q5\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vmul.i32 q9, q9, q0\n"
|
|
"vmul.i32 q10, q10, q0\n"
|
|
"vmul.i32 q11, q11, q0\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vadd.i32 q9, q9, q1\n"
|
|
"vadd.i32 q10, q10, q1\n"
|
|
"vadd.i32 q11, q11, q1\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vshl.s32 q9, q9, q2\n"
|
|
"vshl.s32 q10, q10, q2\n"
|
|
"vshl.s32 q11, q11, q2\n"
|
|
"vqmovn.s32 d24, q6\n"
|
|
"vqmovn.s32 d25, q7\n"
|
|
"vqmovn.s32 d26, q8\n"
|
|
"vqmovn.s32 d27, q9\n"
|
|
"vqmovn.s32 d28, q10\n"
|
|
"vqmovn.s32 d29, q11\n"
|
|
"vqmovun.s16 d24, q12\n"
|
|
"vqmovun.s16 d26, q13\n"
|
|
"vqmovun.s16 d28, q14\n"
|
|
"vst1.8 {d24}, [%[destination]]!\n"
|
|
"vst1.8 {d26}, [r1]!\n"
|
|
"vst1.8 {d28}, [r3]!\n"
|
|
|
|
"bne 1b\n"
|
|
"2:"
|
|
"vld1.32 {d12}, [%[source]:64]\n"
|
|
"vld1.32 {d16}, [r0:64]\n"
|
|
"vld1.32 {d20}, [r2:64]\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vadd.i32 q10, q10, q5\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vmul.i32 q10, q10, q0\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vadd.i32 q10, q10, q1\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vshl.s32 q10, q10, q2\n"
|
|
"vqmovn.s32 d24, q6\n"
|
|
"vqmovn.s32 d26, q8\n"
|
|
"vqmovn.s32 d28, q10\n"
|
|
"vqmovun.s16 d24, q12\n"
|
|
"vqmovun.s16 d26, q13\n"
|
|
"vqmovun.s16 d28, q14\n"
|
|
"vst1.16 {d24[0]}, [%[destination]]\n"
|
|
"vst1.16 {d26[0]}, [r1]\n"
|
|
"vst1.16 {d28[0]}, [r3]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "r0", "r1", "r2", "r3", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
|
|
"d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
|
|
"d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
|
|
"d28", "d29", "cc", "memory");
|
|
}
|
|
|
|
void qnt_3x8_3(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination, std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
|
|
std::int32_t shift) {
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d10[], d11[]}, [%[offsets]:32]!\n"
|
|
"add r0, %[source], %[stride]\n"
|
|
"add r1, %[destination], %[destination_stride]\n"
|
|
"add r2, r0, %[stride]\n"
|
|
"add r3, r1, %[destination_stride]\n"
|
|
"subs %[count], %[count], #3\n"
|
|
"beq 2f\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d12, d13, d14, d15}, [%[source]:64]!\n"
|
|
"vld1.32 {d16, d17, d18, d19}, [r0:64]!\n"
|
|
"vld1.32 {d20, d21, d22, d23}, [r2:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"pld [r0]\n"
|
|
"pld [r2]\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q3\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vadd.i32 q9, q9, q4\n"
|
|
"vadd.i32 q10, q10, q5\n"
|
|
"vadd.i32 q11, q11, q5\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vmul.i32 q9, q9, q0\n"
|
|
"vmul.i32 q10, q10, q0\n"
|
|
"vmul.i32 q11, q11, q0\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vadd.i32 q9, q9, q1\n"
|
|
"vadd.i32 q10, q10, q1\n"
|
|
"vadd.i32 q11, q11, q1\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vshl.s32 q9, q9, q2\n"
|
|
"vshl.s32 q10, q10, q2\n"
|
|
"vshl.s32 q11, q11, q2\n"
|
|
"vqmovn.s32 d24, q6\n"
|
|
"vqmovn.s32 d25, q7\n"
|
|
"vqmovn.s32 d26, q8\n"
|
|
"vqmovn.s32 d27, q9\n"
|
|
"vqmovn.s32 d28, q10\n"
|
|
"vqmovn.s32 d29, q11\n"
|
|
"vqmovun.s16 d24, q12\n"
|
|
"vqmovun.s16 d26, q13\n"
|
|
"vqmovun.s16 d28, q14\n"
|
|
"vst1.8 {d24}, [%[destination]]!\n"
|
|
"vst1.8 {d26}, [r1]!\n"
|
|
"vst1.8 {d28}, [r3]!\n"
|
|
|
|
"bne 1b\n"
|
|
"2:"
|
|
"vld1.32 {d12}, [%[source]:64]!\n"
|
|
"vld1.32 {d16}, [r0:64]!\n"
|
|
"vld1.32 {d20}, [r2:64]!\n"
|
|
"vld1.32 {d13[0]}, [%[source]]\n"
|
|
"vld1.32 {d17[0]}, [r0]\n"
|
|
"vld1.32 {d21[0]}, [r2]\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vadd.i32 q10, q10, q5\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vmul.i32 q10, q10, q0\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vadd.i32 q10, q10, q1\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vshl.s32 q10, q10, q2\n"
|
|
"vqmovn.s32 d24, q6\n"
|
|
"vqmovn.s32 d26, q8\n"
|
|
"vqmovn.s32 d28, q10\n"
|
|
"vqmovun.s16 d24, q12\n"
|
|
"vqmovun.s16 d26, q13\n"
|
|
"vqmovun.s16 d28, q14\n"
|
|
"vst1.16 {d24[0]}, [%[destination]]!\n"
|
|
"vst1.16 {d26[0]}, [r1]!\n"
|
|
"vst1.16 {d28[0]}, [r3]!\n"
|
|
"vst1.8 {d24[2]}, [%[destination]]\n"
|
|
"vst1.8 {d26[2]}, [r1]\n"
|
|
"vst1.8 {d28[2]}, [r3]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "r0", "r1", "r2", "r3", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
|
|
"d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
|
|
"d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
|
|
"d28", "d29", "cc", "memory");
|
|
}
|
|
|
|
void qnt_3x8_4(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination, std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
|
|
std::int32_t shift) {
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d10[], d11[]}, [%[offsets]:32]!\n"
|
|
"add r0, %[source], %[stride]\n"
|
|
"add r1, %[destination], %[destination_stride]\n"
|
|
"add r2, r0, %[stride]\n"
|
|
"add r3, r1, %[destination_stride]\n"
|
|
"subs %[count], %[count], #4\n"
|
|
"beq 2f\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d12, d13, d14, d15}, [%[source]:64]!\n"
|
|
"vld1.32 {d16, d17, d18, d19}, [r0:64]!\n"
|
|
"vld1.32 {d20, d21, d22, d23}, [r2:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"pld [r0]\n"
|
|
"pld [r2]\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q3\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vadd.i32 q9, q9, q4\n"
|
|
"vadd.i32 q10, q10, q5\n"
|
|
"vadd.i32 q11, q11, q5\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vmul.i32 q9, q9, q0\n"
|
|
"vmul.i32 q10, q10, q0\n"
|
|
"vmul.i32 q11, q11, q0\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vadd.i32 q9, q9, q1\n"
|
|
"vadd.i32 q10, q10, q1\n"
|
|
"vadd.i32 q11, q11, q1\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vshl.s32 q9, q9, q2\n"
|
|
"vshl.s32 q10, q10, q2\n"
|
|
"vshl.s32 q11, q11, q2\n"
|
|
"vqmovn.s32 d24, q6\n"
|
|
"vqmovn.s32 d25, q7\n"
|
|
"vqmovn.s32 d26, q8\n"
|
|
"vqmovn.s32 d27, q9\n"
|
|
"vqmovn.s32 d28, q10\n"
|
|
"vqmovn.s32 d29, q11\n"
|
|
"vqmovun.s16 d24, q12\n"
|
|
"vqmovun.s16 d26, q13\n"
|
|
"vqmovun.s16 d28, q14\n"
|
|
"vst1.8 {d24}, [%[destination]]!\n"
|
|
"vst1.8 {d26}, [r1]!\n"
|
|
"vst1.8 {d28}, [r3]!\n"
|
|
|
|
"bne 1b\n"
|
|
"2:"
|
|
"vld1.32 {d12, d13}, [%[source]:64]\n"
|
|
"vld1.32 {d16, d17}, [r0:64]\n"
|
|
"vld1.32 {d20, d21}, [r2:64]\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vadd.i32 q10, q10, q5\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vmul.i32 q10, q10, q0\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vadd.i32 q10, q10, q1\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vshl.s32 q10, q10, q2\n"
|
|
"vqmovn.s32 d24, q6\n"
|
|
"vqmovn.s32 d26, q8\n"
|
|
"vqmovn.s32 d28, q10\n"
|
|
"vqmovun.s16 d24, q12\n"
|
|
"vqmovun.s16 d26, q13\n"
|
|
"vqmovun.s16 d28, q14\n"
|
|
"vst1.32 {d24[0]}, [%[destination]]\n"
|
|
"vst1.32 {d26[0]}, [r1]\n"
|
|
"vst1.32 {d28[0]}, [r3]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "r0", "r1", "r2", "r3", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
|
|
"d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
|
|
"d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
|
|
"d28", "d29", "cc", "memory");
|
|
}
|
|
|
|
void qnt_3x8_5(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination, std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
|
|
std::int32_t shift) {
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d10[], d11[]}, [%[offsets]:32]!\n"
|
|
"add r0, %[source], %[stride]\n"
|
|
"add r1, %[destination], %[destination_stride]\n"
|
|
"add r2, r0, %[stride]\n"
|
|
"add r3, r1, %[destination_stride]\n"
|
|
"subs %[count], %[count], #5\n"
|
|
"beq 2f\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d12, d13, d14, d15}, [%[source]:64]!\n"
|
|
"vld1.32 {d16, d17, d18, d19}, [r0:64]!\n"
|
|
"vld1.32 {d20, d21, d22, d23}, [r2:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"pld [r0]\n"
|
|
"pld [r2]\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q3\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vadd.i32 q9, q9, q4\n"
|
|
"vadd.i32 q10, q10, q5\n"
|
|
"vadd.i32 q11, q11, q5\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vmul.i32 q9, q9, q0\n"
|
|
"vmul.i32 q10, q10, q0\n"
|
|
"vmul.i32 q11, q11, q0\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vadd.i32 q9, q9, q1\n"
|
|
"vadd.i32 q10, q10, q1\n"
|
|
"vadd.i32 q11, q11, q1\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vshl.s32 q9, q9, q2\n"
|
|
"vshl.s32 q10, q10, q2\n"
|
|
"vshl.s32 q11, q11, q2\n"
|
|
"vqmovn.s32 d24, q6\n"
|
|
"vqmovn.s32 d25, q7\n"
|
|
"vqmovn.s32 d26, q8\n"
|
|
"vqmovn.s32 d27, q9\n"
|
|
"vqmovn.s32 d28, q10\n"
|
|
"vqmovn.s32 d29, q11\n"
|
|
"vqmovun.s16 d24, q12\n"
|
|
"vqmovun.s16 d26, q13\n"
|
|
"vqmovun.s16 d28, q14\n"
|
|
"vst1.8 {d24}, [%[destination]]!\n"
|
|
"vst1.8 {d26}, [r1]!\n"
|
|
"vst1.8 {d28}, [r3]!\n"
|
|
|
|
"bne 1b\n"
|
|
"2:"
|
|
"vld1.32 {d12, d13}, [%[source]:64]!\n"
|
|
"vld1.32 {d16, d17}, [r0:64]!\n"
|
|
"vld1.32 {d20, d21}, [r2:64]!\n"
|
|
"vld1.32 {d14[0]}, [%[source]]\n"
|
|
"vld1.32 {d18[0]}, [r0]\n"
|
|
"vld1.32 {d22[0]}, [r2]\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q3\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vadd.i32 q9, q9, q4\n"
|
|
"vadd.i32 q10, q10, q5\n"
|
|
"vadd.i32 q11, q11, q5\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vmul.i32 q9, q9, q0\n"
|
|
"vmul.i32 q10, q10, q0\n"
|
|
"vmul.i32 q11, q11, q0\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vadd.i32 q9, q9, q1\n"
|
|
"vadd.i32 q10, q10, q1\n"
|
|
"vadd.i32 q11, q11, q1\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vshl.s32 q9, q9, q2\n"
|
|
"vshl.s32 q10, q10, q2\n"
|
|
"vshl.s32 q11, q11, q2\n"
|
|
"vqmovn.s32 d24, q6\n"
|
|
"vqmovn.s32 d25, q7\n"
|
|
"vqmovn.s32 d26, q8\n"
|
|
"vqmovn.s32 d27, q9\n"
|
|
"vqmovn.s32 d28, q10\n"
|
|
"vqmovn.s32 d29, q11\n"
|
|
"vqmovun.s16 d24, q12\n"
|
|
"vqmovun.s16 d26, q13\n"
|
|
"vqmovun.s16 d28, q14\n"
|
|
"vst1.32 {d24[0]}, [%[destination]]!\n"
|
|
"vst1.32 {d26[0]}, [r1]!\n"
|
|
"vst1.32 {d28[0]}, [r3]!\n"
|
|
"vst1.8 {d24[4]}, [%[destination]]\n"
|
|
"vst1.8 {d26[4]}, [r1]\n"
|
|
"vst1.8 {d28[4]}, [r3]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "r0", "r1", "r2", "r3", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
|
|
"d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
|
|
"d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
|
|
"d28", "d29", "cc", "memory");
|
|
}
|
|
|
|
void qnt_3x8_6(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination, std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
|
|
std::int32_t shift) {
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d10[], d11[]}, [%[offsets]:32]!\n"
|
|
"add r0, %[source], %[stride]\n"
|
|
"add r1, %[destination], %[destination_stride]\n"
|
|
"add r2, r0, %[stride]\n"
|
|
"add r3, r1, %[destination_stride]\n"
|
|
"subs %[count], %[count], #6\n"
|
|
"beq 2f\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d12, d13, d14, d15}, [%[source]:64]!\n"
|
|
"vld1.32 {d16, d17, d18, d19}, [r0:64]!\n"
|
|
"vld1.32 {d20, d21, d22, d23}, [r2:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"pld [r0]\n"
|
|
"pld [r2]\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q3\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vadd.i32 q9, q9, q4\n"
|
|
"vadd.i32 q10, q10, q5\n"
|
|
"vadd.i32 q11, q11, q5\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vmul.i32 q9, q9, q0\n"
|
|
"vmul.i32 q10, q10, q0\n"
|
|
"vmul.i32 q11, q11, q0\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vadd.i32 q9, q9, q1\n"
|
|
"vadd.i32 q10, q10, q1\n"
|
|
"vadd.i32 q11, q11, q1\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vshl.s32 q9, q9, q2\n"
|
|
"vshl.s32 q10, q10, q2\n"
|
|
"vshl.s32 q11, q11, q2\n"
|
|
"vqmovn.s32 d24, q6\n"
|
|
"vqmovn.s32 d25, q7\n"
|
|
"vqmovn.s32 d26, q8\n"
|
|
"vqmovn.s32 d27, q9\n"
|
|
"vqmovn.s32 d28, q10\n"
|
|
"vqmovn.s32 d29, q11\n"
|
|
"vqmovun.s16 d24, q12\n"
|
|
"vqmovun.s16 d26, q13\n"
|
|
"vqmovun.s16 d28, q14\n"
|
|
"vst1.8 {d24}, [%[destination]]!\n"
|
|
"vst1.8 {d26}, [r1]!\n"
|
|
"vst1.8 {d28}, [r3]!\n"
|
|
|
|
"bne 1b\n"
|
|
"2:"
|
|
"vld1.32 {d12, d13, d14}, [%[source]:64]\n"
|
|
"vld1.32 {d16, d17, d18}, [r0:64]\n"
|
|
"vld1.32 {d20, d21, d22}, [r2:64]\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q3\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vadd.i32 q9, q9, q4\n"
|
|
"vadd.i32 q10, q10, q5\n"
|
|
"vadd.i32 q11, q11, q5\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vmul.i32 q9, q9, q0\n"
|
|
"vmul.i32 q10, q10, q0\n"
|
|
"vmul.i32 q11, q11, q0\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vadd.i32 q9, q9, q1\n"
|
|
"vadd.i32 q10, q10, q1\n"
|
|
"vadd.i32 q11, q11, q1\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vshl.s32 q9, q9, q2\n"
|
|
"vshl.s32 q10, q10, q2\n"
|
|
"vshl.s32 q11, q11, q2\n"
|
|
"vqmovn.s32 d24, q6\n"
|
|
"vqmovn.s32 d25, q7\n"
|
|
"vqmovn.s32 d26, q8\n"
|
|
"vqmovn.s32 d27, q9\n"
|
|
"vqmovn.s32 d28, q10\n"
|
|
"vqmovn.s32 d29, q11\n"
|
|
"vqmovun.s16 d24, q12\n"
|
|
"vqmovun.s16 d26, q13\n"
|
|
"vqmovun.s16 d28, q14\n"
|
|
"vst1.32 {d24[0]}, [%[destination]]!\n"
|
|
"vst1.32 {d26[0]}, [r1]!\n"
|
|
"vst1.32 {d28[0]}, [r3]!\n"
|
|
"vst1.16 {d24[2]}, [%[destination]]\n"
|
|
"vst1.16 {d26[2]}, [r1]\n"
|
|
"vst1.16 {d28[2]}, [r3]\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "r0", "r1", "r2", "r3", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
|
|
"d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
|
|
"d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
|
|
"d28", "d29", "cc", "memory");
|
|
}
|
|
|
|
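// Tail stores write exactly count % 8 bytes per row by chaining lane stores:
// a 4-byte vst1.32, a 2-byte vst1.16, then a 1-byte vst1.8, so nothing past
// the end of each output row is touched. qnt_3x8_7 below uses all three.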
void qnt_3x8_7(const std::int32_t* source, std::int32_t count,
|
|
std::int32_t stride, const std::int32_t* offsets,
|
|
std::uint8_t* destination, std::int32_t destination_stride,
|
|
std::int32_t multiplicative_offset, std::int32_t rounding_offset,
|
|
std::int32_t shift) {
|
|
asm volatile(
|
|
"vdup.32 q0, %[multiplicative_offset]\n"
|
|
"vdup.32 q1, %[rounding_offset]\n"
|
|
"vdup.32 q2, %[shift]\n"
|
|
"vld1.32 {d6[], d7[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d8[], d9[]}, [%[offsets]:32]!\n"
|
|
"vld1.32 {d10[], d11[]}, [%[offsets]:32]!\n"
|
|
"add r0, %[source], %[stride]\n"
|
|
"add r1, %[destination], %[destination_stride]\n"
|
|
"add r2, r0, %[stride]\n"
|
|
"add r3, r1, %[destination_stride]\n"
|
|
"subs %[count], %[count], #7\n"
|
|
"beq 2f\n"
|
|
|
|
"1:"
|
|
"subs %[count], %[count], #8\n"
|
|
"vld1.32 {d12, d13, d14, d15}, [%[source]:64]!\n"
|
|
"vld1.32 {d16, d17, d18, d19}, [r0:64]!\n"
|
|
"vld1.32 {d20, d21, d22, d23}, [r2:64]!\n"
|
|
"pld [%[source]]\n"
|
|
"pld [r0]\n"
|
|
"pld [r2]\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q3\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vadd.i32 q9, q9, q4\n"
|
|
"vadd.i32 q10, q10, q5\n"
|
|
"vadd.i32 q11, q11, q5\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vmul.i32 q9, q9, q0\n"
|
|
"vmul.i32 q10, q10, q0\n"
|
|
"vmul.i32 q11, q11, q0\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vadd.i32 q9, q9, q1\n"
|
|
"vadd.i32 q10, q10, q1\n"
|
|
"vadd.i32 q11, q11, q1\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vshl.s32 q9, q9, q2\n"
|
|
"vshl.s32 q10, q10, q2\n"
|
|
"vshl.s32 q11, q11, q2\n"
|
|
"vqmovn.s32 d24, q6\n"
|
|
"vqmovn.s32 d25, q7\n"
|
|
"vqmovn.s32 d26, q8\n"
|
|
"vqmovn.s32 d27, q9\n"
|
|
"vqmovn.s32 d28, q10\n"
|
|
"vqmovn.s32 d29, q11\n"
|
|
"vqmovun.s16 d24, q12\n"
|
|
"vqmovun.s16 d26, q13\n"
|
|
"vqmovun.s16 d28, q14\n"
|
|
"vst1.8 {d24}, [%[destination]]!\n"
|
|
"vst1.8 {d26}, [r1]!\n"
|
|
"vst1.8 {d28}, [r3]!\n"
|
|
|
|
"bne 1b\n"
|
|
"2:"
|
|
"vld1.32 {d12, d13, d14}, [%[source]:64]!\n"
|
|
"vld1.32 {d16, d17, d18}, [r0:64]!\n"
|
|
"vld1.32 {d20, d21, d22}, [r2:64]!\n"
|
|
"vld1.32 {d15[0]}, [%[source]]\n"
|
|
"vld1.32 {d19[0]}, [r0]\n"
|
|
"vld1.32 {d23[0]}, [r2]\n"
|
|
"vadd.i32 q6, q6, q3\n"
|
|
"vadd.i32 q7, q7, q3\n"
|
|
"vadd.i32 q8, q8, q4\n"
|
|
"vadd.i32 q9, q9, q4\n"
|
|
"vadd.i32 q10, q10, q5\n"
|
|
"vadd.i32 q11, q11, q5\n"
|
|
"vmul.i32 q6, q6, q0\n"
|
|
"vmul.i32 q7, q7, q0\n"
|
|
"vmul.i32 q8, q8, q0\n"
|
|
"vmul.i32 q9, q9, q0\n"
|
|
"vmul.i32 q10, q10, q0\n"
|
|
"vmul.i32 q11, q11, q0\n"
|
|
"vadd.i32 q6, q6, q1\n"
|
|
"vadd.i32 q7, q7, q1\n"
|
|
"vadd.i32 q8, q8, q1\n"
|
|
"vadd.i32 q9, q9, q1\n"
|
|
"vadd.i32 q10, q10, q1\n"
|
|
"vadd.i32 q11, q11, q1\n"
|
|
"vshl.s32 q6, q6, q2\n"
|
|
"vshl.s32 q7, q7, q2\n"
|
|
"vshl.s32 q8, q8, q2\n"
|
|
"vshl.s32 q9, q9, q2\n"
|
|
"vshl.s32 q10, q10, q2\n"
|
|
"vshl.s32 q11, q11, q2\n"
|
|
"vqmovn.s32 d24, q6\n"
|
|
"vqmovn.s32 d25, q7\n"
|
|
"vqmovn.s32 d26, q8\n"
|
|
"vqmovn.s32 d27, q9\n"
|
|
"vqmovn.s32 d28, q10\n"
|
|
"vqmovn.s32 d29, q11\n"
|
|
"vqmovun.s16 d24, q12\n"
|
|
"vqmovun.s16 d26, q13\n"
|
|
"vqmovun.s16 d28, q14\n"
|
|
"vst1.32 {d24[0]}, [%[destination]]!\n"
|
|
"vst1.32 {d26[0]}, [r1]!\n"
|
|
"vst1.32 {d28[0]}, [r3]!\n"
|
|
"vst1.16 {d24[2]}, [%[destination]]!\n"
|
|
"vst1.16 {d26[2]}, [r1]!\n"
|
|
"vst1.16 {d28[2]}, [r3]!\n"
|
|
"vst1.8 {d24[6]}, [%[destination]]!\n"
|
|
"vst1.8 {d26[6]}, [r1]!\n"
|
|
"vst1.8 {d28[6]}, [r3]!\n"
|
|
: [count] "+r"(count),
|
|
[multiplicative_offset] "+r"(multiplicative_offset),
|
|
[stride] "+r"(stride), [shift] "+r"(shift),
|
|
[destination] "+r"(destination), [offsets] "+r"(offsets),
|
|
[source] "+r"(source), [destination_stride] "+r"(destination_stride),
|
|
[rounding_offset] "+r"(rounding_offset)
|
|
:
|
|
: "r0", "r1", "r2", "r3", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
|
|
"d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17",
|
|
"d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
|
|
"d28", "d29", "cc", "memory");
|
|
}
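
// The multi_qnt_* helpers dispatch on count % 8 to the qnt_* variant that
// handles the matching leftover-column count; the *_aligned forms forward to
// the *_aligned quantization kernels.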

void multi_qnt_1x8_aligned(const std::int32_t* source, std::int32_t count,
                           std::int32_t stride, const std::int32_t* offsets,
                           std::uint8_t* destination,
                           std::int32_t destination_stride,
                           std::int32_t multiplicative_offset,
                           std::int32_t rounding_offset, std::int32_t shift) {
  switch (count % 8) {
    case 0:
      qnt_1x8_aligned(source, count, stride, offsets, destination,
                      destination_stride, multiplicative_offset,
                      rounding_offset, shift);
      break;
    case 1:
      qnt_1x8_1_aligned(source, count, stride, offsets, destination,
                        destination_stride, multiplicative_offset,
                        rounding_offset, shift);
      break;
    case 2:
      qnt_1x8_2_aligned(source, count, stride, offsets, destination,
                        destination_stride, multiplicative_offset,
                        rounding_offset, shift);
      break;
    case 3:
      qnt_1x8_3_aligned(source, count, stride, offsets, destination,
                        destination_stride, multiplicative_offset,
                        rounding_offset, shift);
      break;
    case 4:
      qnt_1x8_4_aligned(source, count, stride, offsets, destination,
                        destination_stride, multiplicative_offset,
                        rounding_offset, shift);
      break;
    case 5:
      qnt_1x8_5_aligned(source, count, stride, offsets, destination,
                        destination_stride, multiplicative_offset,
                        rounding_offset, shift);
      break;
    case 6:
      qnt_1x8_6_aligned(source, count, stride, offsets, destination,
                        destination_stride, multiplicative_offset,
                        rounding_offset, shift);
      break;
    case 7:
      qnt_1x8_7_aligned(source, count, stride, offsets, destination,
                        destination_stride, multiplicative_offset,
                        rounding_offset, shift);
      break;
  }
}

void multi_qnt_2x8_aligned(const std::int32_t* source, std::int32_t count,
                           std::int32_t stride, const std::int32_t* offsets,
                           std::uint8_t* destination,
                           std::int32_t destination_stride,
                           std::int32_t multiplicative_offset,
                           std::int32_t rounding_offset, std::int32_t shift) {
  switch (count % 8) {
    case 0:
      qnt_2x8_aligned(source, count, stride, offsets, destination,
                      destination_stride, multiplicative_offset,
                      rounding_offset, shift);
      break;
    case 1:
      qnt_2x8_1_aligned(source, count, stride, offsets, destination,
                        destination_stride, multiplicative_offset,
                        rounding_offset, shift);
      break;
    case 2:
      qnt_2x8_2_aligned(source, count, stride, offsets, destination,
                        destination_stride, multiplicative_offset,
                        rounding_offset, shift);
      break;
    case 3:
      qnt_2x8_3_aligned(source, count, stride, offsets, destination,
                        destination_stride, multiplicative_offset,
                        rounding_offset, shift);
      break;
    case 4:
      qnt_2x8_4_aligned(source, count, stride, offsets, destination,
                        destination_stride, multiplicative_offset,
                        rounding_offset, shift);
      break;
    case 5:
      qnt_2x8_5_aligned(source, count, stride, offsets, destination,
                        destination_stride, multiplicative_offset,
                        rounding_offset, shift);
      break;
    case 6:
      qnt_2x8_6_aligned(source, count, stride, offsets, destination,
                        destination_stride, multiplicative_offset,
                        rounding_offset, shift);
      break;
    case 7:
      qnt_2x8_7_aligned(source, count, stride, offsets, destination,
                        destination_stride, multiplicative_offset,
                        rounding_offset, shift);
      break;
  }
}

void multi_qnt_3x8_aligned(const std::int32_t* source, std::int32_t count,
                           std::int32_t stride, const std::int32_t* offsets,
                           std::uint8_t* destination,
                           std::int32_t destination_stride,
                           std::int32_t multiplicative_offset,
                           std::int32_t rounding_offset, std::int32_t shift) {
  switch (count % 8) {
    case 0:
      qnt_3x8_aligned(source, count, stride, offsets, destination,
                      destination_stride, multiplicative_offset,
                      rounding_offset, shift);
      break;
    case 1:
      qnt_3x8_1_aligned(source, count, stride, offsets, destination,
                        destination_stride, multiplicative_offset,
                        rounding_offset, shift);
      break;
    case 2:
      qnt_3x8_2_aligned(source, count, stride, offsets, destination,
                        destination_stride, multiplicative_offset,
                        rounding_offset, shift);
      break;
    case 3:
      qnt_3x8_3_aligned(source, count, stride, offsets, destination,
                        destination_stride, multiplicative_offset,
                        rounding_offset, shift);
      break;
    case 4:
      qnt_3x8_4_aligned(source, count, stride, offsets, destination,
                        destination_stride, multiplicative_offset,
                        rounding_offset, shift);
      break;
    case 5:
      qnt_3x8_5_aligned(source, count, stride, offsets, destination,
                        destination_stride, multiplicative_offset,
                        rounding_offset, shift);
      break;
    case 6:
      qnt_3x8_6_aligned(source, count, stride, offsets, destination,
                        destination_stride, multiplicative_offset,
                        rounding_offset, shift);
      break;
    case 7:
      qnt_3x8_7_aligned(source, count, stride, offsets, destination,
                        destination_stride, multiplicative_offset,
                        rounding_offset, shift);
      break;
  }
}
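
// Unaligned counterparts: the same count % 8 dispatch, forwarding to the
// qnt_* kernels that store to the destination without alignment annotations.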

void multi_qnt_1x8(const std::int32_t* source, std::int32_t count,
                   std::int32_t stride, const std::int32_t* offsets,
                   std::uint8_t* destination, std::int32_t destination_stride,
                   std::int32_t multiplicative_offset,
                   std::int32_t rounding_offset, std::int32_t shift) {
  switch (count % 8) {
    case 0:
      qnt_1x8(source, count, stride, offsets, destination, destination_stride,
              multiplicative_offset, rounding_offset, shift);
      break;
    case 1:
      qnt_1x8_1(source, count, stride, offsets, destination, destination_stride,
                multiplicative_offset, rounding_offset, shift);
      break;
    case 2:
      qnt_1x8_2(source, count, stride, offsets, destination, destination_stride,
                multiplicative_offset, rounding_offset, shift);
      break;
    case 3:
      qnt_1x8_3(source, count, stride, offsets, destination, destination_stride,
                multiplicative_offset, rounding_offset, shift);
      break;
    case 4:
      qnt_1x8_4(source, count, stride, offsets, destination, destination_stride,
                multiplicative_offset, rounding_offset, shift);
      break;
    case 5:
      qnt_1x8_5(source, count, stride, offsets, destination, destination_stride,
                multiplicative_offset, rounding_offset, shift);
      break;
    case 6:
      qnt_1x8_6(source, count, stride, offsets, destination, destination_stride,
                multiplicative_offset, rounding_offset, shift);
      break;
    case 7:
      qnt_1x8_7(source, count, stride, offsets, destination, destination_stride,
                multiplicative_offset, rounding_offset, shift);
      break;
  }
}

void multi_qnt_2x8(const std::int32_t* source, std::int32_t count,
                   std::int32_t stride, const std::int32_t* offsets,
                   std::uint8_t* destination, std::int32_t destination_stride,
                   std::int32_t multiplicative_offset,
                   std::int32_t rounding_offset, std::int32_t shift) {
  switch (count % 8) {
    case 0:
      qnt_2x8(source, count, stride, offsets, destination, destination_stride,
              multiplicative_offset, rounding_offset, shift);
      break;
    case 1:
      qnt_2x8_1(source, count, stride, offsets, destination, destination_stride,
                multiplicative_offset, rounding_offset, shift);
      break;
    case 2:
      qnt_2x8_2(source, count, stride, offsets, destination, destination_stride,
                multiplicative_offset, rounding_offset, shift);
      break;
    case 3:
      qnt_2x8_3(source, count, stride, offsets, destination, destination_stride,
                multiplicative_offset, rounding_offset, shift);
      break;
    case 4:
      qnt_2x8_4(source, count, stride, offsets, destination, destination_stride,
                multiplicative_offset, rounding_offset, shift);
      break;
    case 5:
      qnt_2x8_5(source, count, stride, offsets, destination, destination_stride,
                multiplicative_offset, rounding_offset, shift);
      break;
    case 6:
      qnt_2x8_6(source, count, stride, offsets, destination, destination_stride,
                multiplicative_offset, rounding_offset, shift);
      break;
    case 7:
      qnt_2x8_7(source, count, stride, offsets, destination, destination_stride,
                multiplicative_offset, rounding_offset, shift);
      break;
  }
}

void multi_qnt_3x8(const std::int32_t* source, std::int32_t count,
                   std::int32_t stride, const std::int32_t* offsets,
                   std::uint8_t* destination, std::int32_t destination_stride,
                   std::int32_t multiplicative_offset,
                   std::int32_t rounding_offset, std::int32_t shift) {
  switch (count % 8) {
    case 0:
      qnt_3x8(source, count, stride, offsets, destination, destination_stride,
              multiplicative_offset, rounding_offset, shift);
      break;
    case 1:
      qnt_3x8_1(source, count, stride, offsets, destination, destination_stride,
                multiplicative_offset, rounding_offset, shift);
      break;
    case 2:
      qnt_3x8_2(source, count, stride, offsets, destination, destination_stride,
                multiplicative_offset, rounding_offset, shift);
      break;
    case 3:
      qnt_3x8_3(source, count, stride, offsets, destination, destination_stride,
                multiplicative_offset, rounding_offset, shift);
      break;
    case 4:
      qnt_3x8_4(source, count, stride, offsets, destination, destination_stride,
                multiplicative_offset, rounding_offset, shift);
      break;
    case 5:
      qnt_3x8_5(source, count, stride, offsets, destination, destination_stride,
                multiplicative_offset, rounding_offset, shift);
      break;
    case 6:
      qnt_3x8_6(source, count, stride, offsets, destination, destination_stride,
                multiplicative_offset, rounding_offset, shift);
      break;
    case 7:
      qnt_3x8_7(source, count, stride, offsets, destination, destination_stride,
                multiplicative_offset, rounding_offset, shift);
      break;
  }
}
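
// gemm_q8_* drivers: pack ("zip") the operands into `scratch`, run the
// 3x8 * 3x8 multiply kernels into a temporary int32 buffer, then requantize
// each group of three result rows to uint8 with multi_qnt_3x8_aligned. The
// trailing index in the name selects the zip variant used for the k % 8
// leftover depth (gemm_q8_0_0_1_aligned uses zip_3x8_1_aligned, and so on).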

void gemm_q8_0_0_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  // Operands are processed in chunks of 3 rows (lhs) by 3 rows (rhs), with the
  // depth padded up to a multiple of 8 for the 3x8 kernels.
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  // Scratch layout: [zipped lhs chunk | zipped rhs | temp int32 results].
  std::uint8_t* zipped_lhs = scratch;
  // The three per-row correction terms live right after the packed lhs data.
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  // const_offset is folded into the per-row terms computed while packing the
  // lhs; rounding_offset pairs with the right shift applied in multi_qnt.
  const std::int32_t const_offset =
      lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  // Pack all rhs chunks once up front.
  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  // For each 3-row lhs chunk: pack it, multiply it against every rhs chunk
  // into the int32 temp buffer, then requantize the three result rows.
  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}
|
|
|
|
void gemm_q8_0_0_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t shift, std::uint8_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_q8_0_0_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t shift, std::uint8_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_q8_0_0_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t shift, std::uint8_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_q8_0_0_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t shift, std::uint8_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_q8_0_0_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t shift, std::uint8_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_q8_0_0_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

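// Scratch layout shared by these kernels, as read off the pointer arithmetic
// above: [zipped lhs chunk | zipped rhs panel | int32 temp results]. A rough
// sizing sketch, inferred from the code rather than guaranteed by any API
// (the actual scratch requirement is defined by the calling gemm entry
// point):
//
//   const std::int32_t padded_k = ((k + 7) / 8) * 8;
//   const std::int32_t scratch_bytes =
//       (padded_k + 16) * 3 +          // zipped lhs rows, 3 at a time
//       (padded_k + 16) * n +          // zipped rhs panel
//       3 * (((n * 4 + 7) / 8) * 8);   // int32 temp results for 3 rows
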
void gemm_q8_0_0_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

void gemm_q8_0_1_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

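// The gemm_q8_0_1_* variants starting above handle n % 3 == 1: after packing
// the full 3-column RHS chunks, the single leftover RHS stripe (one result
// column) is packed with zip_1x8_*_aligned and multiplied with
// mul_3x8_1x8_int32_rhsadd before each 3-row result block is quantized.
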
void gemm_q8_0_1_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

void gemm_q8_0_1_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

void gemm_q8_0_1_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

void gemm_q8_0_1_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

void gemm_q8_0_1_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

void gemm_q8_0_1_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

void gemm_q8_0_1_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

void gemm_q8_0_2_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

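// The gemm_q8_0_2_* variants handle n % 3 == 2: the two leftover RHS stripes
// are packed with zip_2x8_*_aligned and multiplied with
// mul_3x8_2x8_int32_rhsadd, mirroring the single-stripe handling in the
// gemm_q8_0_1_* variants above.
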
void gemm_q8_0_2_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

void gemm_q8_0_2_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

void gemm_q8_0_2_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

void gemm_q8_0_2_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

void gemm_q8_0_2_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

void gemm_q8_0_2_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

void gemm_q8_0_2_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

void gemm_q8_1_0_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_1_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}

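// The gemm_q8_1_0_* variants handle m % 3 == 1: after the full 3-row chunks,
// the final leftover LHS row is packed with zip_1x8_*_aligned, multiplied
// against every zipped RHS chunk with mul_1x8_3x8_int32_rhsadd, and quantized
// with multi_qnt_1x8_aligned using the row aggregate stored at
// zipped_lhs_1_offsets (just past the packed row, as written by the zip_1x8
// kernels).
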
void gemm_q8_1_0_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_1_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}

void gemm_q8_1_0_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_1_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}

void gemm_q8_1_0_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_1_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}

void gemm_q8_1_0_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_1_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}

void gemm_q8_1_0_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_1_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}

void gemm_q8_1_0_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_1_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}

void gemm_q8_1_0_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_1_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}

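// Inferred from the calls below: the gemm_q8_1_1_* variants differ from the
// gemm_q8_1_0_* ones only in that the RHS has one leftover column. The RHS
// packing loop is followed by a zip_1x8* call for that column, and each
// multiply loop is followed by mul_3x8_1x8_int32_rhsadd (3-row chunks) or
// mul_1x8_1x8_int32_rhsadd (the trailing row) before requantization.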
void gemm_q8_1_1_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_1_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}

void gemm_q8_1_1_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_1_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}

void gemm_q8_1_1_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_1_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}

void gemm_q8_1_1_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_1_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}

void gemm_q8_1_1_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_1_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}

void gemm_q8_1_1_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_1_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}

void gemm_q8_1_1_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_1_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}

void gemm_q8_1_1_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_1_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}

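// Inferred from the calls below: the gemm_q8_1_2_* variants handle two
// leftover RHS columns. The RHS packing ends with a zip_2x8* call, and the
// multiply loops are followed by mul_3x8_2x8_int32_rhsadd (3-row chunks) or
// mul_1x8_2x8_int32_rhsadd (the trailing row) before requantization.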
void gemm_q8_1_2_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_1_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}

void gemm_q8_1_2_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_1_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}

void gemm_q8_1_2_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_1_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}

void gemm_q8_1_2_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_1_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}

void gemm_q8_1_2_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t shift, std::uint8_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_1_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}

void gemm_q8_1_2_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
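  // Same pack/multiply/requantize structure as gemm_q8_1_2_4_aligned above;
  // the leftover kernels below correspond to m % 3 == 1, n % 3 == 2 and
  // k % 8 == 5.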
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_1_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}

void gemm_q8_1_2_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
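  // As above, with leftover handling for m % 3 == 1, n % 3 == 2 and
  // k % 8 == 6.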
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_1_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}

void gemm_q8_1_2_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
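  // As above, with leftover handling for m % 3 == 1, n % 3 == 2 and
  // k % 8 == 7.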
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_1x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_1_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}

void gemm_q8_2_0_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
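  // Variant for m % 3 == 2 with n a multiple of 3 and k a multiple of 8:
  // only a 2-row leftover pass is needed after the 3x8-by-3x8 block loop,
  // and the packing kernels need no depth leftover handling.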
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                          zipped_lhs_3_offsets, result_chunk, result_stride,
                          multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_2x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  multi_qnt_2x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
                        zipped_lhs_2_offsets, result_chunk, result_stride,
                        multiplicative_offset, rounding_offset, -shift);
}

void gemm_q8_2_0_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
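  // As gemm_q8_2_0_0_aligned, but the packing kernels handle a k % 8 == 1
  // depth leftover.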
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
multi_qnt_2x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_2_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}

void gemm_q8_2_0_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
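  // As gemm_q8_2_0_0_aligned, but the packing kernels handle a k % 8 == 2
  // depth leftover.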
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
multi_qnt_2x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_2_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}

void gemm_q8_2_0_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
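  // As gemm_q8_2_0_0_aligned, but the packing kernels handle a k % 8 == 3
  // depth leftover.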
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
multi_qnt_2x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_2_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}

void gemm_q8_2_0_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
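  // As gemm_q8_2_0_0_aligned, but the packing kernels handle a k % 8 == 4
  // depth leftover.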
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
multi_qnt_2x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_2_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}

void gemm_q8_2_0_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
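  // As gemm_q8_2_0_0_aligned, but the packing kernels handle a k % 8 == 5
  // depth leftover.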
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
multi_qnt_2x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_2_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}

void gemm_q8_2_0_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
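  // As gemm_q8_2_0_0_aligned, but the packing kernels handle a k % 8 == 6
  // depth leftover.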
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
multi_qnt_2x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_2_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}

void gemm_q8_2_0_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
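  // As gemm_q8_2_0_0_aligned, but the packing kernels handle a k % 8 == 7
  // depth leftover.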
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
multi_qnt_2x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_2_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}

void gemm_q8_2_1_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
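  // Variant for m % 3 == 2 and n % 3 == 1 (2-row and 1-column leftover
  // passes below), with k a multiple of 8.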
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_2x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_2_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}

void gemm_q8_2_1_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
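  // As gemm_q8_2_1_0_aligned, but the packing kernels handle a k % 8 == 1
  // depth leftover.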
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_2x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_2_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}

void gemm_q8_2_1_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
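  // As gemm_q8_2_1_0_aligned, but the packing kernels handle a k % 8 == 2
  // depth leftover.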
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_2x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_2_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}

void gemm_q8_2_1_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
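  // As gemm_q8_2_1_0_aligned, but the packing kernels handle a k % 8 == 3
  // depth leftover.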
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_2x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_2_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}

void gemm_q8_2_1_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
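  // As gemm_q8_2_1_0_aligned, but the packing kernels handle a k % 8 == 4
  // depth leftover.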
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_2x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_2_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}

void gemm_q8_2_1_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
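  // As gemm_q8_2_1_0_aligned, but the packing kernels handle a k % 8 == 5
  // depth leftover.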
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_2x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_2_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}

void gemm_q8_2_1_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t result_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, std::uint8_t* result,
                           std::int32_t result_stride) {
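  // As gemm_q8_2_1_0_aligned, but the packing kernels handle a k % 8 == 6
  // depth leftover.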
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_2x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_2_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}
|
|
|
|
void gemm_q8_2_1_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t shift, std::uint8_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_2x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_2_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}
|
|
|
|
void gemm_q8_2_2_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t shift, std::uint8_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_2x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_2_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}
|
|
|
|
void gemm_q8_2_2_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t shift, std::uint8_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_2x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_2_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}
|
|
|
|
void gemm_q8_2_2_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t shift, std::uint8_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_2x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_2_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}
|
|
|
|
void gemm_q8_2_2_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t shift, std::uint8_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_2x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_2_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}
|
|
|
|
void gemm_q8_2_2_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t shift, std::uint8_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_2x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_2_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}
|
|
|
|
void gemm_q8_2_2_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t shift, std::uint8_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_2x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_2_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}
|
|
|
|
void gemm_q8_2_2_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t shift, std::uint8_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_2x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_2_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}
|
|
|
|
void gemm_q8_2_2_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset,
|
|
std::int32_t shift, std::uint8_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_3x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_2x8_aligned(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_2_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}
|
|
|
|
void gemm_q8_0_0_0(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset, std::int32_t shift,
|
|
std::uint8_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_q8_0_0_1(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset, std::int32_t shift,
|
|
std::uint8_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_q8_0_0_2(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset, std::int32_t shift,
|
|
std::uint8_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_q8_0_0_3(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset, std::int32_t shift,
|
|
std::uint8_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_q8_0_0_4(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset, std::int32_t shift,
|
|
std::uint8_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_q8_0_0_5(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset, std::int32_t shift,
|
|
std::uint8_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_q8_0_0_6(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset, std::int32_t shift,
|
|
std::uint8_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_q8_0_0_7(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset, std::int32_t shift,
|
|
std::uint8_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_q8_0_1_0(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset, std::int32_t shift,
|
|
std::uint8_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_q8_0_1_1(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

void gemm_q8_0_1_2(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

void gemm_q8_0_1_3(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

void gemm_q8_0_1_4(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

void gemm_q8_0_1_5(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

void gemm_q8_0_1_6(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

void gemm_q8_0_1_7(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

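// The gemm_q8_0_2_* variants below differ from the gemm_q8_0_1_* family only
// in the right-hand-side leftover handling: with two leftover columns, the
// packing and the multiply fall through to the zip_2x8* and
// mul_3x8_2x8_int32_rhsadd kernels instead of the 1x8 versions (descriptive
// comment, not generated code).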
void gemm_q8_0_2_0(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

void gemm_q8_0_2_1(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

void gemm_q8_0_2_2(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

void gemm_q8_0_2_3(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

void gemm_q8_0_2_4(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

void gemm_q8_0_2_5(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

void gemm_q8_0_2_6(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

void gemm_q8_0_2_7(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

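// The gemm_q8_1_0_* variants below additionally handle one leftover lhs row:
// after the 3-row stripes, the last row is packed with zip_1x8*, multiplied
// with mul_1x8_3x8_int32_rhsadd and requantized with multi_qnt_1x8
// (descriptive comment, not generated code). The numeric suffixes appear to
// encode the leftover sizes, i.e. gemm_q8_<m % 3>_<n % 3>_<k % 8>; a caller
// selecting a specialization by hand might do something like the following
// illustrative sketch, which is not part of the generated code:
//   if (m % 3 == 1 && n % 3 == 0 && k % 8 == 0) {
//     gemm_q8_1_0_0(scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
//                   result_offset, multiplicative_offset, shift, result,
//                   result_stride);
//   }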
void gemm_q8_1_0_0(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  multi_qnt_1x8(temp_result, n, mul_result_chunk_stride_bytes,
                zipped_lhs_1_offsets, result_chunk, result_stride,
                multiplicative_offset, rounding_offset, -shift);
}

void gemm_q8_1_0_1(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  multi_qnt_1x8(temp_result, n, mul_result_chunk_stride_bytes,
                zipped_lhs_1_offsets, result_chunk, result_stride,
                multiplicative_offset, rounding_offset, -shift);
}

void gemm_q8_1_0_2(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  multi_qnt_1x8(temp_result, n, mul_result_chunk_stride_bytes,
                zipped_lhs_1_offsets, result_chunk, result_stride,
                multiplicative_offset, rounding_offset, -shift);
}

void gemm_q8_1_0_3(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  multi_qnt_1x8(temp_result, n, mul_result_chunk_stride_bytes,
                zipped_lhs_1_offsets, result_chunk, result_stride,
                multiplicative_offset, rounding_offset, -shift);
}

void gemm_q8_1_0_4(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  multi_qnt_1x8(temp_result, n, mul_result_chunk_stride_bytes,
                zipped_lhs_1_offsets, result_chunk, result_stride,
                multiplicative_offset, rounding_offset, -shift);
}

void gemm_q8_1_0_5(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  multi_qnt_1x8(temp_result, n, mul_result_chunk_stride_bytes,
                zipped_lhs_1_offsets, result_chunk, result_stride,
                multiplicative_offset, rounding_offset, -shift);
}

void gemm_q8_1_0_6(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  multi_qnt_1x8(temp_result, n, mul_result_chunk_stride_bytes,
                zipped_lhs_1_offsets, result_chunk, result_stride,
                multiplicative_offset, rounding_offset, -shift);
}

void gemm_q8_1_0_7(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset, std::int32_t shift,
|
|
std::uint8_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
multi_qnt_1x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_1_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}
|
|
|
|
void gemm_q8_1_1_0(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset, std::int32_t shift,
|
|
std::uint8_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_1x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_1_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}
|
|
|
|
void gemm_q8_1_1_1(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset, std::int32_t shift,
|
|
std::uint8_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_1x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_1_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}
|
|
|
|
void gemm_q8_1_1_2(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset, std::int32_t shift,
|
|
std::uint8_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_1x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_1_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}
|
|
|
|
void gemm_q8_1_1_3(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset, std::int32_t shift,
|
|
std::uint8_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_1x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_1_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}
|
|
|
|
void gemm_q8_1_1_4(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset, std::int32_t shift,
|
|
std::uint8_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_1x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_1_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}
|
|
|
|
void gemm_q8_1_1_5(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset, std::int32_t shift,
|
|
std::uint8_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_1x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_1_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}
|
|
|
|
void gemm_q8_1_1_6(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset, std::int32_t shift,
|
|
std::uint8_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_1x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_1_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}
|
|
|
|
void gemm_q8_1_1_7(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset, std::int32_t shift,
|
|
std::uint8_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_1x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_1_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}
|
|
|
|
void gemm_q8_1_2_0(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset, std::int32_t shift,
|
|
std::uint8_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_1x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_1_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}
|
|
|
|
void gemm_q8_1_2_1(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset, std::int32_t shift,
|
|
std::uint8_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_1x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_1_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}
|
|
|
|
void gemm_q8_1_2_2(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset, std::int32_t shift,
|
|
std::uint8_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_1x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_1_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}
|
|
|
|
void gemm_q8_1_2_3(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset, std::int32_t shift,
|
|
std::uint8_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_1x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_1_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}
|
|
|
|
void gemm_q8_1_2_4(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset, std::int32_t shift,
|
|
std::uint8_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_1x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_1_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}
|
|
|
|
void gemm_q8_1_2_5(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset, std::int32_t shift,
|
|
std::uint8_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_1x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_1_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}
|
|
|
|
void gemm_q8_1_2_6(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset, std::int32_t shift,
|
|
std::uint8_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_1x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_1_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}
|
|
|
|
void gemm_q8_1_2_7(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_1x8(temp_result, n, mul_result_chunk_stride_bytes,
                zipped_lhs_1_offsets, result_chunk, result_stride,
                multiplicative_offset, rounding_offset, -shift);
}

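// Editorial note (not generated code): the kernels in this block follow a
// common naming scheme, gemm_q8_<rows>_<cols>_<depth>, where the suffixes
// appear to encode the leftover sizes m % 3, n % 3 and k % 8 that the
// variant is specialized for; the same depth suffix shows up on the zip_*
// helpers each variant calls. Every variant runs three phases:
//   1. zip_*: pack the RHS once (and the LHS per 3-row band) into scratch
//      blocks of padded_k + 16 bytes per row, alongside per-row sums used
//      for the quantization offsets;
//   2. mul_*_int32_rhsadd: multiply the packed 3x8 / 2x8 / 1x8 blocks into
//      the int32 temp_result buffer, with the RHS-side sums already folded
//      in (as the _rhsadd suffix suggests);
//   3. multi_qnt_*: add the LHS-side sums, apply multiplicative_offset,
//      rounding_offset and the right shift, then store uint8 results.
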
void gemm_q8_2_0_0(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_2x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  multi_qnt_2x8(temp_result, n, mul_result_chunk_stride_bytes,
                zipped_lhs_2_offsets, result_chunk, result_stride,
                multiplicative_offset, rounding_offset, -shift);
}

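// Illustrative sketch only (not part of the generated kernel set): one way
// gemm_q8_2_0_0 above could be driven directly, assuming m % 3 == 2,
// n % 3 == 0 and k % 8 == 0. The kernels are presumably selected by
// dispatch code elsewhere in this header; the helper name and the scratch
// sizing below are hypothetical, the latter merely mirroring the pointer
// arithmetic inside gemm_q8_2_0_0 (packed LHS band + packed RHS + three
// rows of int32 temporaries).
inline void example_gemm_q8_2_0_0(const std::uint8_t* lhs,
                                  const std::uint8_t* rhs, std::int32_t m,
                                  std::int32_t n, std::int32_t k,
                                  std::uint8_t* result) {
  assert(m % 3 == 2 && n % 3 == 0 && k % 8 == 0);
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  // Packed LHS band + packed RHS + int32 temporaries, as laid out above.
  const std::int32_t scratch_size = (padded_k + 16) * 3 +
                                    (padded_k + 16) * n +
                                    ((n * 4 + 7) / 8) * 8 * 3;
  // new[] typically returns memory aligned for any fundamental type, which
  // satisfies the 64-bit alignment the packed buffers are accessed with.
  std::uint8_t* scratch = new std::uint8_t[scratch_size];
  // Example quantization parameters; real values come from the caller.
  const std::int32_t lhs_offset = -128;
  const std::int32_t rhs_offset = -128;
  const std::int32_t result_offset = 127;
  const std::int32_t multiplicative_offset = 1;
  const std::int32_t shift = 7;
  gemm_q8_2_0_0(scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                result_offset, multiplicative_offset, shift, result,
                /*result_stride=*/n);
  delete[] scratch;
}
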
void gemm_q8_2_0_1(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_2x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  multi_qnt_2x8(temp_result, n, mul_result_chunk_stride_bytes,
                zipped_lhs_2_offsets, result_chunk, result_stride,
                multiplicative_offset, rounding_offset, -shift);
}

void gemm_q8_2_0_2(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_2x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  multi_qnt_2x8(temp_result, n, mul_result_chunk_stride_bytes,
                zipped_lhs_2_offsets, result_chunk, result_stride,
                multiplicative_offset, rounding_offset, -shift);
}

void gemm_q8_2_0_3(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_2x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  multi_qnt_2x8(temp_result, n, mul_result_chunk_stride_bytes,
                zipped_lhs_2_offsets, result_chunk, result_stride,
                multiplicative_offset, rounding_offset, -shift);
}

void gemm_q8_2_0_4(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_2x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  multi_qnt_2x8(temp_result, n, mul_result_chunk_stride_bytes,
                zipped_lhs_2_offsets, result_chunk, result_stride,
                multiplicative_offset, rounding_offset, -shift);
}

void gemm_q8_2_0_5(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_2x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  multi_qnt_2x8(temp_result, n, mul_result_chunk_stride_bytes,
                zipped_lhs_2_offsets, result_chunk, result_stride,
                multiplicative_offset, rounding_offset, -shift);
}

void gemm_q8_2_0_6(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_2x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  multi_qnt_2x8(temp_result, n, mul_result_chunk_stride_bytes,
                zipped_lhs_2_offsets, result_chunk, result_stride,
                multiplicative_offset, rounding_offset, -shift);
}

void gemm_q8_2_0_7(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_2x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  multi_qnt_2x8(temp_result, n, mul_result_chunk_stride_bytes,
                zipped_lhs_2_offsets, result_chunk, result_stride,
                multiplicative_offset, rounding_offset, -shift);
}

void gemm_q8_2_1_0(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_2x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_2x8(temp_result, n, mul_result_chunk_stride_bytes,
                zipped_lhs_2_offsets, result_chunk, result_stride,
                multiplicative_offset, rounding_offset, -shift);
}

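// Note (editorial): starting with gemm_q8_2_1_0 above, both operands carry
// leftovers. The single remaining RHS row (n % 3 == 1) is packed once with
// zip_1x8* right after the column loop, each 3-row LHS band finishes its
// sweep with mul_3x8_1x8_int32_rhsadd on that strip, and the final 2-row
// LHS band (m % 3 == 2) repeats the same pattern with
// mul_2x8_1x8_int32_rhsadd and multi_qnt_2x8, reusing the packed RHS
// through zipped_rhs_chunk.
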
void gemm_q8_2_1_1(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_2x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_2x8(temp_result, n, mul_result_chunk_stride_bytes,
                zipped_lhs_2_offsets, result_chunk, result_stride,
                multiplicative_offset, rounding_offset, -shift);
}

void gemm_q8_2_1_2(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_2x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_2x8(temp_result, n, mul_result_chunk_stride_bytes,
                zipped_lhs_2_offsets, result_chunk, result_stride,
                multiplicative_offset, rounding_offset, -shift);
}

void gemm_q8_2_1_3(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_2x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_2x8(temp_result, n, mul_result_chunk_stride_bytes,
                zipped_lhs_2_offsets, result_chunk, result_stride,
                multiplicative_offset, rounding_offset, -shift);
}

void gemm_q8_2_1_4(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_2x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_2x8(temp_result, n, mul_result_chunk_stride_bytes,
                zipped_lhs_2_offsets, result_chunk, result_stride,
                multiplicative_offset, rounding_offset, -shift);
}

void gemm_q8_2_1_5(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_2x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_2x8(temp_result, n, mul_result_chunk_stride_bytes,
                zipped_lhs_2_offsets, result_chunk, result_stride,
                multiplicative_offset, rounding_offset, -shift);
}

void gemm_q8_2_1_6(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_2x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_2x8(temp_result, n, mul_result_chunk_stride_bytes,
                zipped_lhs_2_offsets, result_chunk, result_stride,
                multiplicative_offset, rounding_offset, -shift);
}

void gemm_q8_2_1_7(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_2x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_1x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_2x8(temp_result, n, mul_result_chunk_stride_bytes,
                zipped_lhs_2_offsets, result_chunk, result_stride,
                multiplicative_offset, rounding_offset, -shift);
}

void gemm_q8_2_2_0(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_2x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_2x8(temp_result, n, mul_result_chunk_stride_bytes,
                zipped_lhs_2_offsets, result_chunk, result_stride,
                multiplicative_offset, rounding_offset, -shift);
}

void gemm_q8_2_2_1(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_2x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_2x8(temp_result, n, mul_result_chunk_stride_bytes,
                zipped_lhs_2_offsets, result_chunk, result_stride,
                multiplicative_offset, rounding_offset, -shift);
}

void gemm_q8_2_2_2(std::uint8_t* scratch, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                   std::int32_t k, std::int32_t lhs_offset,
                   std::int32_t rhs_offset, std::int32_t result_offset,
                   std::int32_t multiplicative_offset, std::int32_t shift,
                   std::uint8_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
  const std::int32_t rounding_offset = (1 << (shift - 1));
  std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
      scratch + zipped_chunk_size + zipped_rhs_size);
  std::uint8_t* result_chunk = result;
  std::int32_t* mul_result_chunk = temp_result;
  const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = temp_result;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                               mul_result_chunk, mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
                  zipped_lhs_3_offsets, result_chunk, result_stride,
                  multiplicative_offset, rounding_offset, -shift);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_2x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = temp_result;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                             mul_result_chunk, mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                           mul_result_chunk, mul_result_chunk_stride_bytes);
  multi_qnt_2x8(temp_result, n, mul_result_chunk_stride_bytes,
                zipped_lhs_2_offsets, result_chunk, result_stride,
                multiplicative_offset, rounding_offset, -shift);
}

void gemm_q8_2_2_3(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset, std::int32_t shift,
|
|
std::uint8_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_2x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_2_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}
|
|
|
|
void gemm_q8_2_2_4(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset, std::int32_t shift,
|
|
std::uint8_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_2x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_2_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}
|
|
|
|
void gemm_q8_2_2_5(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset, std::int32_t shift,
|
|
std::uint8_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_2x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_2_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}
|
|
|
|
void gemm_q8_2_2_6(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset, std::int32_t shift,
|
|
std::uint8_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_2x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_2_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}
|
|
|
|
void gemm_q8_2_2_7(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t result_offset,
|
|
std::int32_t multiplicative_offset, std::int32_t shift,
|
|
std::uint8_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k + result_offset;
|
|
const std::int32_t rounding_offset = (1 << (shift - 1));
|
|
std::int32_t* temp_result = reinterpret_cast<std::int32_t*>(
|
|
scratch + zipped_chunk_size + zipped_rhs_size);
|
|
std::uint8_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = temp_result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = ((n * 4 + 7) / 8) * 8;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_3x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_3_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = temp_result;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_int32_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk, mul_result_chunk_stride_bytes);
|
|
multi_qnt_2x8(temp_result, n, mul_result_chunk_stride_bytes,
|
|
zipped_lhs_2_offsets, result_chunk, result_stride,
|
|
multiplicative_offset, rounding_offset, -shift);
|
|
}
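
// The gemm_i32_* variants below follow the same packing and kernel pattern
// as the gemm_q8_* drivers above, but accumulate directly into the caller's
// int32 result array via the mul_*_int32_lhsadd_rhsadd kernels, with no
// requantization pass. The _aligned suffix presumably marks the variants
// that require suitably aligned operands and strides.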
void gemm_i32_0_0_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

void gemm_i32_0_0_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_0_0_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_0_0_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_0_0_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_0_0_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_0_0_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_0_0_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
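
// gemm_i32_0_1_* variants: identical to gemm_i32_0_0_* above except that one
// leftover RHS strip is packed with zip_1x8_* and handled by the
// mul_3x8_1x8_int32_lhsadd_rhsadd kernel after the full 3-wide strips.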
void gemm_i32_0_1_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

void gemm_i32_0_1_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_0_1_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_0_1_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_0_1_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_0_1_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_0_1_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_0_1_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
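
// gemm_i32_0_2_* variants: as above, but the leftover RHS strip is two wide,
// packed with zip_2x8_* and multiplied by mul_3x8_2x8_int32_lhsadd_rhsadd.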
void gemm_i32_0_2_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

void gemm_i32_0_2_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_0_2_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_0_2_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_0_2_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_0_2_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_0_2_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_0_2_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_1_0_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_1_0_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_1_0_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_1_0_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_1_0_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_1_0_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_1_0_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_1_0_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_1_1_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_1_1_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_1_1_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_1_1_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_1_1_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_1_1_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_1_1_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_1_1_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}

void gemm_i32_1_2_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
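
// Scratch layout common to these kernels, as implied by the pointer
// arithmetic above: the packed LHS chunk lives at the start of |scratch|
// and the packed RHS starts at scratch + zipped_chunk_size, occupying
// zipped_rhs_size bytes. That suggests the caller must provide at least
//   (padded_k + 16) * 3 + (padded_k + 16) * n  ==  (padded_k + 16) * (n + 3)
// bytes of scratch. The exact size and alignment contract is an inference
// from this code, not a documented guarantee, so verify it against the
// dispatching code before relying on it.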
|
|
|
|
void gemm_i32_1_2_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_1_2_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_1_2_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_1_2_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_1_2_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_1_2_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_1_2_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}

void gemm_i32_2_0_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_2x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}
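
// Why const_offset = lhs_offset * rhs_offset * k: with the uint8 inputs
// interpreted as (a + lhs_offset) and (b + rhs_offset), each result entry is
//   sum_k (a_ik + lhs_offset) * (b_jk + rhs_offset)
//     = sum_k a_ik * b_jk
//       + rhs_offset * sum_k a_ik
//       + lhs_offset * sum_k b_jk
//       + k * lhs_offset * rhs_offset.
// The zip_* packers store each row's sum scaled by the opposite side's
// offset (that is what their multiplicative/additive offset arguments
// carry: lhs_offset with additive 0 for the RHS, rhs_offset with additive
// const_offset for the LHS), and the mul_*_lhsadd_rhsadd kernels add those
// stored terms to the raw dot products. This reading is inferred from the
// call sites above, not from separate documentation.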
|
|
|
|
void gemm_i32_2_0_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_2_0_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_2_0_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_2_0_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_2_0_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_2_0_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_2_0_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}

void gemm_i32_2_1_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                            const std::uint8_t* rhs, std::int32_t m,
                            std::int32_t n, std::int32_t k,
                            std::int32_t lhs_offset, std::int32_t rhs_offset,
                            std::int32_t* result, std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_2x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}
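
// Illustrative call sketch (kept in a comment; the sizes, the scratch
// estimate and the placeholder names lhs_data / rhs_data / lhs_offset /
// rhs_offset below are assumptions for illustration, not part of the
// generated kernels; both operands appear to be laid out as rows of
// length k, given the stride-k packing above):
//
//   // m = 5, n = 4, k = 16  ->  m % 3 == 2, n % 3 == 1, k % 8 == 0,
//   // so the matching variant is gemm_i32_2_1_0_aligned.
//   const std::int32_t m = 5, n = 4, k = 16;
//   const std::int32_t padded_k = ((k + 7) / 8) * 8;
//   std::vector<std::uint8_t> scratch((padded_k + 16) * (n + 3));
//   std::vector<std::int32_t> result(m * n);
//   gemm_i32_2_1_0_aligned(scratch.data(), lhs_data, rhs_data, m, n, k,
//                          lhs_offset, rhs_offset, result.data(),
//                          /*result_stride=*/n);
//
// The "_aligned" suffix presumably also requires suitably aligned input
// pointers; treat that, like the scratch sizing, as an assumption to check
// against the selection logic that chooses between these variants.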
|
|
|
|
void gemm_i32_2_1_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_2_1_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_2_1_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_2_1_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_2_1_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_2_1_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_2_1_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_2_2_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_2_2_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_2_2_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_2_2_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_2_2_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_2_2_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_2_2_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_2_2_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
std::int32_t* result, std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_0_0_0(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_0_0_1(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_0_0_2(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_0_0_3(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_0_0_4(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_0_0_5(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_0_0_6(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_0_0_7(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_0_1_0(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_0_1_1(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_0_1_2(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_0_1_3(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_0_1_4(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_0_1_5(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

void gemm_i32_0_1_6(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

void gemm_i32_0_1_7(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

void gemm_i32_0_2_0(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

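// Scratch layout (restated here for reference; the library's public entry
// points remain the authoritative source for the exact requirement): the
// zipped LHS panel occupies zipped_chunk_size = (padded_k + 16) * 3 bytes
// starting at scratch, and the zipped RHS occupies zipped_rhs_size =
// (padded_k + 16) * n bytes immediately after it. A hand-sized scratch buffer
// for these kernels would therefore need at least roughly:
//
//   const std::int32_t padded_k = ((k + 7) / 8) * 8;
//   const std::int32_t scratch_bytes = (padded_k + 16) * (3 + n);
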
void gemm_i32_0_2_1(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

void gemm_i32_0_2_2(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

void gemm_i32_0_2_3(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

void gemm_i32_0_2_4(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

void gemm_i32_0_2_5(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_0_2_6(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_0_2_7(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_1_0_0(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}

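// Leftover-row handling (gemm_i32_1_*_* variants): the main loop stops after
// the last full 3-row LHS chunk; the single remaining LHS row is then zipped
// with zip_1x8* and multiplied against every zipped RHS panel by
// mul_1x8_3x8_int32_lhsadd_rhsadd, writing one extra result row starting
// row_chunks * 3 rows into the output (where result_chunk points once the
// main loop has finished).
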
void gemm_i32_1_0_1(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_1_0_2(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_1_0_3(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_1_0_4(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_1_0_5(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_1_0_6(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_1_0_7(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_1_1_0(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}

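// Combined leftover handling (gemm_i32_1_1_* variants): mul_3x8_1x8 covers the
// leftover RHS column for each full LHS row chunk, mul_1x8_3x8 covers the
// leftover LHS row against each full RHS column chunk, and the final
// mul_1x8_1x8_int32_lhsadd_rhsadd call fills the single corner element where
// the leftover row and the leftover column meet.
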
void gemm_i32_1_1_1(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_1_1_2(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_1_1_3(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_1_1_4(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_1_1_5(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}

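// Usage sketch for the variant above (illustrative only; the sizes, operands
// and zero offsets are assumptions, not taken from this header). For a
// (4 x 13) * (13 x 7) product, m % 3 == 1, n % 3 == 1 and k % 8 == 5, so
// gemm_i32_1_1_5 applies. Based on the pointer arithmetic above, the scratch
// buffer needs (padded_k + 16) * (n + 3) bytes for the zipped copies, and the
// k-strided zip calls expect lhs as 4 rows of 13 bytes and rhs as 7 rows of
// 13 bytes (i.e. the right-hand operand stored column-major):
//
//   std::vector<std::uint8_t> scratch((16 + 16) * (7 + 3));
//   std::vector<std::int32_t> result(4 * 7);
//   gemm_i32_1_1_5(scratch.data(), lhs, rhs, 4, 7, 13, /*lhs_offset=*/0,
//                  /*rhs_offset=*/0, result.data(), /*result_stride=*/7);
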
void gemm_i32_1_1_6(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}

void gemm_i32_1_1_7(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}

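// The _1_2_* family below pairs a one-row LHS tail with a two-column RHS tail:
// the row tail goes through zip_1x8* / mul_1x8_*, the column tail through
// zip_2x8* / mul_*_2x8.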
void gemm_i32_1_2_0(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}

void gemm_i32_1_2_1(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}

void gemm_i32_1_2_2(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_1_2_3(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_1_2_4(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_1_2_5(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_1_2_6(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_1_2_7(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
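
// The _2_0_* family handles two leftover LHS rows and no leftover RHS columns:
// there is no RHS tail zip and no mul_*_1x8 / mul_*_2x8 column tail, only a
// final zip_2x8* / mul_2x8_3x8 pass over the two remaining rows.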
void gemm_i32_2_0_0(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_2x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}

void gemm_i32_2_0_1(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_2_0_2(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_2_0_3(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_2_0_4(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_2_0_5(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_2_0_6(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_2_0_7(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_i32_2_1_0(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_2_1_1(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_2_1_2(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_2_1_3(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_2x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}

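// Editor's note -- minimal usage sketch for the variant above, assuming (as
// the zip/mul calls suggest) that gemm_i32_2_1_3 is the specialization for
// m % 3 == 2, n % 3 == 1 and k % 8 == 3, that lhs is m rows of k bytes, that
// rhs is supplied pre-transposed as n rows of k bytes, and that scratch is
// sized as in gemm_scratch_size_sketch above. The offsets and the function
// name example_gemm_i32_2_1_3 are made up for illustration only.
inline void example_gemm_i32_2_1_3() {
  constexpr std::int32_t m = 5, n = 4, k = 11;  // 5%3==2, 4%3==1, 11%8==3.
  static std::uint8_t lhs[m * k] = {0};
  static std::uint8_t rhs[n * k] = {0};
  static std::int32_t result[m * n];
  // (padded_k + 16) * (n + 3) = 32 * 7 = 224 bytes; rounded up generously
  // and over-aligned as a precaution for the NEON stores into scratch.
  alignas(16) static std::uint8_t scratch[256];
  gemm_i32_2_1_3(scratch, lhs, rhs, m, n, k, /*lhs_offset=*/0,
                 /*rhs_offset=*/0, result, /*result_stride=*/n);
}
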
void gemm_i32_2_1_4(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_2x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}

void gemm_i32_2_1_5(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_2x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}

void gemm_i32_2_1_6(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_2x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}

void gemm_i32_2_1_7(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_2x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_1x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}

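// Editor's note: the gemm_i32_2_2_* variants that follow differ from the
// gemm_i32_2_1_* group above only in how they finish the RHS -- the final
// n % 3 == 2 rows are packed with the zip_2x8* kernels and multiplied with
// mul_*x8_2x8_int32_lhsadd_rhsadd instead of the 1x8 forms, while the
// trailing digit still selects the k % 8 leftover handled by the zip kernels.
// (Editorial reading of the generated code, not an original comment.)
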
void gemm_i32_2_2_0(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_2x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}

void gemm_i32_2_2_1(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, std::int32_t* result,
                    std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  std::int32_t* result_chunk = result;
  std::int32_t* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                      mul_result_chunk,
                                      mul_result_chunk_stride_bytes);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_2x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                    mul_result_chunk,
                                    mul_result_chunk_stride_bytes);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes);
}

void gemm_i32_2_2_2(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_2_2_3(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_2_2_4(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_2_2_5(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_2_2_6(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
void gemm_i32_2_2_7(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, std::int32_t* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
std::int32_t* result_chunk = result;
|
|
std::int32_t* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_int32_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes);
|
|
}
|
|
|
|
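// Editor's note: the gemm_f_*_aligned variants below reuse the same packing
// scheme as the gemm_i32_* routines but finish through the
// mul_*_float_lhsadd_rhsadd kernels, which take an extra result_scale and
// write float output instead of int32. The _aligned suffix matches the
// *_aligned zip kernels, whose loads carry 64-bit alignment hints, so these
// entry points presumably expect 8-byte aligned lhs/rhs rows. (Editorial
// reading of the generated code, not an original comment.)
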
void gemm_f_0_0_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

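// Editor's note -- a hedged usage sketch for the aligned float variant above.
// It assumes m and n divisible by 3 and k divisible by 8 (the _0_0_0 suffix),
// 8-byte aligned k-byte rows for lhs and rhs, scratch sized as in
// gemm_scratch_size_sketch, and that result_scale is applied when the int32
// accumulation is converted to float. All names and values below are
// illustrative only.
inline void example_gemm_f_0_0_0_aligned() {
  constexpr std::int32_t m = 3, n = 3, k = 8;
  alignas(8) static std::uint8_t lhs[m * k] = {0};
  alignas(8) static std::uint8_t rhs[n * k] = {0};
  static float result[m * n];
  // (padded_k + 16) * (n + 3) = 24 * 6 = 144 bytes; rounded up generously.
  alignas(16) static std::uint8_t scratch[192];
  gemm_f_0_0_0_aligned(scratch, lhs, rhs, m, n, k, /*lhs_offset=*/0,
                       /*rhs_offset=*/0, /*result_scale=*/1.0f / 255.0f,
                       result, /*result_stride=*/n);
}
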
void gemm_f_0_0_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_0_0_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_0_0_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_0_0_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_0_0_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_0_0_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_0_0_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_0_1_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

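// Editor's note: the gemm_f_0_1_*_aligned variants handle the n % 3 == 1
// leftover -- the last RHS row is packed with zip_1x8_*_aligned and finished
// with mul_3x8_1x8_float_lhsadd_rhsadd -- while m stays a multiple of 3 and
// the trailing digit again selects the k % 8 leftover. (Editorial summary of
// the code above, not part of the generated header.)
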
void gemm_f_0_1_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_0_1_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_0_1_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_0_1_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_0_1_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_0_1_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_0_1_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_0_2_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_0_2_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_0_2_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_0_2_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_0_2_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_0_2_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_0_2_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_0_2_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_1_0_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_f_1_0_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_f_1_0_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_f_1_0_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_f_1_0_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_f_1_0_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_f_1_0_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_f_1_0_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_f_1_1_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
|
|
|
|
void gemm_f_1_1_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
|
|
|
|
void gemm_f_1_1_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
|
|
|
|
void gemm_f_1_1_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
|
|
|
|
void gemm_f_1_1_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
|
|
|
|
void gemm_f_1_1_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
|
|
|
|
void gemm_f_1_1_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
|
|
|
|
void gemm_f_1_1_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
|
|
|
|
void gemm_f_1_2_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
|
|
|
|
void gemm_f_1_2_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
|
|
|
|
void gemm_f_1_2_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
|
|
|
|
void gemm_f_1_2_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
|
|
|
|
void gemm_f_1_2_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
|
|
|
|
void gemm_f_1_2_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
|
|
|
|
void gemm_f_1_2_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
|
|
|
|
void gemm_f_1_2_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_1x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
|
|
|
|
void gemm_f_2_0_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_f_2_0_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_f_2_0_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_f_2_0_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_f_2_0_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_f_2_0_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_f_2_0_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_f_2_0_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_f_2_1_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
                          const std::uint8_t* rhs, std::int32_t m,
                          std::int32_t n, std::int32_t k,
                          std::int32_t lhs_offset, std::int32_t rhs_offset,
                          float result_scale, float* result,
                          std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_2_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_2x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_2x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_2x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}

void gemm_f_2_1_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
|
|
|
|
void gemm_f_2_1_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
|
|
|
|
void gemm_f_2_1_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
|
|
|
|
void gemm_f_2_1_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
|
|
|
|
void gemm_f_2_1_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
|
|
|
|
void gemm_f_2_1_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
|
|
|
|
void gemm_f_2_1_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
|
|
|
|
void gemm_f_2_2_0_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
|
|
|
|
void gemm_f_2_2_1_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_1_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_1_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
|
|
|
|
void gemm_f_2_2_2_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_2_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_2_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
|
|
|
|
void gemm_f_2_2_3_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_3_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_3_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
|
|
|
|
void gemm_f_2_2_4_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_4_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_4_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
|
|
|
|
void gemm_f_2_2_5_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_5_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_5_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
|
|
|
|
void gemm_f_2_2_6_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_6_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_6_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
|
|
|
|
void gemm_f_2_2_7_aligned(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m,
|
|
std::int32_t n, std::int32_t k,
|
|
std::int32_t lhs_offset, std::int32_t rhs_offset,
|
|
float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_7_aligned(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_7_aligned(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
|
|
|
|
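// The gemm_f_<M>_<N>_<K> functions without the _aligned suffix appear to be
// the counterparts of the kernels above for inputs that do not satisfy the
// alignment hints used by the *_aligned zip/mul kernels; the surrounding
// packing and multiply structure is otherwise the same.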
void gemm_f_0_0_0(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

void gemm_f_0_0_1(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }
}

void gemm_f_0_0_2(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_0_0_3(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_0_0_4(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_0_0_5(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_0_0_6(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_0_0_7(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_0_1_0(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_0_1_1(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_0_1_2(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_0_1_3(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_0_1_4(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_0_1_5(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_0_1_6(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_0_1_7(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_0_2_0(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_0_2_1(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_0_2_2(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_0_2_3(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_0_2_4(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_0_2_5(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_0_2_6(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_0_2_7(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
}
|
|
|
|
void gemm_f_1_0_0(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_f_1_0_1(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_f_1_0_2(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_f_1_0_3(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_f_1_0_4(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_f_1_0_5(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_f_1_0_6(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_1_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_1x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_1x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_f_1_0_7(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
}

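// gemm_f_1_1_* variants: one leftover row (m % 3 == 1) and one leftover
// column (n % 3 == 1). The extra column is packed with zip_1x8* after the
// main column loop and multiplied with the *_1x8 kernels; the extra row is
// handled after the main row loop with the mul_1x8_* kernels.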
void gemm_f_1_1_0(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}

void gemm_f_1_1_1(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}

void gemm_f_1_1_2(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}

void gemm_f_1_1_3(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}

void gemm_f_1_1_4(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}

void gemm_f_1_1_5(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}

void gemm_f_1_1_6(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}

void gemm_f_1_1_7(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_1x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_1x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}

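// gemm_f_1_2_* variants: one leftover row (m % 3 == 1) and two leftover
// columns (n % 3 == 2). The trailing column pair is packed with zip_2x8* and
// multiplied with the *_2x8 kernels.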
void gemm_f_1_2_0(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}

void gemm_f_1_2_1(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}

void gemm_f_1_2_2(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}

void gemm_f_1_2_3(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}

void gemm_f_1_2_4(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}

void gemm_f_1_2_5(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}

void gemm_f_1_2_6(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}

void gemm_f_1_2_7(std::uint8_t* scratch, const std::uint8_t* lhs,
                  const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                  std::int32_t k, std::int32_t lhs_offset,
                  std::int32_t rhs_offset, float result_scale, float* result,
                  std::int32_t result_stride) {
  const std::int32_t row_chunks = m / 3;
  const std::int32_t col_chunks = n / 3;
  const std::int32_t padded_k = ((k + 7) / 8) * 8;
  const std::int32_t chunk_size = k * 3;
  const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
  const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
  const std::uint8_t* lhs_chunk = lhs;
  const std::uint8_t* rhs_chunk = rhs;
  std::uint8_t* zipped_lhs = scratch;
  std::int32_t* zipped_lhs_3_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
  std::int32_t* zipped_lhs_1_offsets =
      reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 1);
  std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
  std::uint8_t* zipped_rhs_chunk = zipped_rhs;
  const std::int32_t result_chunk_stride = result_stride * 3;

  const std::int32_t const_offset = lhs_offset * rhs_offset * k;
  float* result_chunk = result;
  float* mul_result_chunk = result;
  const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;

  for (int i = 0; i < col_chunks; ++i) {
    zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
    rhs_chunk += chunk_size;
    zipped_rhs_chunk += zipped_chunk_size;
  }
  zip_2x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);

  for (int i = 0; i < row_chunks; ++i) {
    zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
    zipped_rhs_chunk = zipped_rhs;
    mul_result_chunk = result_chunk;
    for (int j = 0; j < col_chunks; ++j) {
      mul_3x8_3x8_float_lhsadd_rhsadd(
          zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
          mul_result_chunk_stride_bytes, result_scale);
      zipped_rhs_chunk += zipped_chunk_size;
      mul_result_chunk += 3;
    }
    mul_3x8_2x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    lhs_chunk += chunk_size;
    result_chunk += result_chunk_stride;
  }

  zip_1x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
  zipped_rhs_chunk = zipped_rhs;
  mul_result_chunk = result_chunk;
  for (int j = 0; j < col_chunks; ++j) {
    mul_1x8_3x8_float_lhsadd_rhsadd(
        zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
        mul_result_chunk_stride_bytes, result_scale);
    zipped_rhs_chunk += zipped_chunk_size;
    mul_result_chunk += 3;
  }
  mul_1x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
                                  mul_result_chunk,
                                  mul_result_chunk_stride_bytes, result_scale);
}

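// gemm_f_2_0_* variants: two leftover rows (m % 3 == 2) and no leftover
// columns. The final row pair is packed with zip_2x8* and multiplied with
// the mul_2x8_3x8 kernel.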
void gemm_f_2_0_0(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_f_2_0_1(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_f_2_0_2(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_f_2_0_3(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
|
|
|
|
void gemm_f_2_0_4(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
void gemm_f_2_0_5(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
void gemm_f_2_0_6(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
void gemm_f_2_0_7(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
}
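// The gemm_f_2_1_* variants additionally zip one leftover RHS chunk (zip_1x8
// and its k-remainder variants) and multiply it with the mul_3x8_1x8 /
// mul_2x8_1x8 float kernels, covering the single result column that remains
// when n % 3 == 1.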
void gemm_f_2_1_0(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
void gemm_f_2_1_1(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
void gemm_f_2_1_2(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
void gemm_f_2_1_3(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
void gemm_f_2_1_4(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
void gemm_f_2_1_5(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
void gemm_f_2_1_6(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
void gemm_f_2_1_7(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_1x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_1x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_1x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
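// The gemm_f_2_2_* variants handle a two-wide leftover RHS chunk instead:
// zip_2x8 and its k-remainder variants pack it, and mul_3x8_2x8 / mul_2x8_2x8
// write the last two result columns (n % 3 == 2).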
void gemm_f_2_2_0(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
void gemm_f_2_2_1(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_1(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_1(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
void gemm_f_2_2_2(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_2(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_2(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
void gemm_f_2_2_3(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_3(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_3(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
void gemm_f_2_2_4(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_4(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_4(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
void gemm_f_2_2_5(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_5(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_5(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
void gemm_f_2_2_6(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_6(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_6(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
void gemm_f_2_2_7(std::uint8_t* scratch, const std::uint8_t* lhs,
|
|
const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
|
|
std::int32_t k, std::int32_t lhs_offset,
|
|
std::int32_t rhs_offset, float result_scale, float* result,
|
|
std::int32_t result_stride) {
|
|
const std::int32_t row_chunks = m / 3;
|
|
const std::int32_t col_chunks = n / 3;
|
|
const std::int32_t padded_k = ((k + 7) / 8) * 8;
|
|
const std::int32_t chunk_size = k * 3;
|
|
const std::int32_t zipped_chunk_size = (padded_k + 16) * 3;
|
|
const std::int32_t zipped_rhs_size = (padded_k + 16) * n;
|
|
const std::uint8_t* lhs_chunk = lhs;
|
|
const std::uint8_t* rhs_chunk = rhs;
|
|
std::uint8_t* zipped_lhs = scratch;
|
|
std::int32_t* zipped_lhs_3_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3);
|
|
std::int32_t* zipped_lhs_2_offsets =
|
|
reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 2);
|
|
std::uint8_t* zipped_rhs = scratch + zipped_chunk_size;
|
|
std::uint8_t* zipped_rhs_chunk = zipped_rhs;
|
|
const std::int32_t result_chunk_stride = result_stride * 3;
|
|
|
|
const std::int32_t const_offset = lhs_offset * rhs_offset * k;
|
|
float* result_chunk = result;
|
|
float* mul_result_chunk = result;
|
|
const std::int32_t mul_result_chunk_stride_bytes = result_stride * 4;
|
|
|
|
for (int i = 0; i < col_chunks; ++i) {
|
|
zip_3x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
rhs_chunk += chunk_size;
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
}
|
|
zip_2x8_7(rhs_chunk, k, k, zipped_rhs_chunk, lhs_offset, 0);
|
|
|
|
for (int i = 0; i < row_chunks; ++i) {
|
|
zip_3x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_3x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_3x8_2x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
lhs_chunk += chunk_size;
|
|
result_chunk += result_chunk_stride;
|
|
}
|
|
|
|
zip_2x8_7(lhs_chunk, k, k, zipped_lhs, rhs_offset, const_offset);
|
|
zipped_rhs_chunk = zipped_rhs;
|
|
mul_result_chunk = result_chunk;
|
|
for (int j = 0; j < col_chunks; ++j) {
|
|
mul_2x8_3x8_float_lhsadd_rhsadd(
|
|
zipped_lhs, zipped_rhs_chunk, padded_k, mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
zipped_rhs_chunk += zipped_chunk_size;
|
|
mul_result_chunk += 3;
|
|
}
|
|
mul_2x8_2x8_float_lhsadd_rhsadd(zipped_lhs, zipped_rhs_chunk, padded_k,
|
|
mul_result_chunk,
|
|
mul_result_chunk_stride_bytes, result_scale);
|
|
}
} // namespace internal
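// Entry point for the quantized 8-bit path: when lhs, rhs and result are
// 8-byte aligned and k and result_stride are multiples of 8, dispatch goes to
// the *_aligned kernels; the residues m % 3, n % 3 and k % 8 then select the
// matching internal::gemm_q8_* specialization.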
void gemm_q8_strided(std::uint8_t* scratch, const std::uint8_t* lhs,
                     const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                     std::int32_t k, std::int32_t lhs_offset,
                     std::int32_t rhs_offset, std::int32_t result_offset,
                     std::int32_t multiplicative_offset, std::int32_t shift,
                     std::uint8_t* result, std::int32_t result_stride) {
  const bool lhs_aligned = ((reinterpret_cast<std::uintptr_t>(lhs) % 8) == 0);
  const bool rhs_aligned = ((reinterpret_cast<std::uintptr_t>(rhs) % 8) == 0);
  const bool k_aligned = ((k % 8) == 0);
  const bool result_aligned =
      ((reinterpret_cast<std::uintptr_t>(result) % 8) == 0);
  const bool result_stride_aligned = ((result_stride % 8) == 0);
  const bool aligned = lhs_aligned && rhs_aligned && result_aligned &&
                       k_aligned && result_stride_aligned;
  if (aligned) {
switch (m % 3) {
|
|
case 0:
|
|
switch (n % 3) {
|
|
case 0:
|
|
switch (k % 8) {
|
|
case 0:
|
|
internal::gemm_q8_0_0_0_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 1:
|
|
internal::gemm_q8_0_0_1_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 2:
|
|
internal::gemm_q8_0_0_2_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 3:
|
|
internal::gemm_q8_0_0_3_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 4:
|
|
internal::gemm_q8_0_0_4_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 5:
|
|
internal::gemm_q8_0_0_5_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 6:
|
|
internal::gemm_q8_0_0_6_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 7:
|
|
internal::gemm_q8_0_0_7_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
}
|
|
break;
|
|
case 1:
|
|
switch (k % 8) {
|
|
case 0:
|
|
internal::gemm_q8_0_1_0_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 1:
|
|
internal::gemm_q8_0_1_1_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 2:
|
|
internal::gemm_q8_0_1_2_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 3:
|
|
internal::gemm_q8_0_1_3_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 4:
|
|
internal::gemm_q8_0_1_4_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 5:
|
|
internal::gemm_q8_0_1_5_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 6:
|
|
internal::gemm_q8_0_1_6_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 7:
|
|
internal::gemm_q8_0_1_7_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
}
|
|
break;
|
|
case 2:
|
|
switch (k % 8) {
|
|
case 0:
|
|
internal::gemm_q8_0_2_0_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 1:
|
|
internal::gemm_q8_0_2_1_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 2:
|
|
internal::gemm_q8_0_2_2_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 3:
|
|
internal::gemm_q8_0_2_3_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 4:
|
|
internal::gemm_q8_0_2_4_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 5:
|
|
internal::gemm_q8_0_2_5_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 6:
|
|
internal::gemm_q8_0_2_6_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 7:
|
|
internal::gemm_q8_0_2_7_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
}
|
|
break;
|
|
}
|
|
break;
|
|
case 1:
|
|
switch (n % 3) {
|
|
case 0:
|
|
switch (k % 8) {
|
|
case 0:
|
|
internal::gemm_q8_1_0_0_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 1:
|
|
internal::gemm_q8_1_0_1_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 2:
|
|
internal::gemm_q8_1_0_2_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 3:
|
|
internal::gemm_q8_1_0_3_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 4:
|
|
internal::gemm_q8_1_0_4_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 5:
|
|
internal::gemm_q8_1_0_5_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 6:
|
|
internal::gemm_q8_1_0_6_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 7:
|
|
internal::gemm_q8_1_0_7_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
}
|
|
break;
|
|
case 1:
|
|
switch (k % 8) {
|
|
case 0:
|
|
internal::gemm_q8_1_1_0_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 1:
|
|
internal::gemm_q8_1_1_1_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 2:
|
|
internal::gemm_q8_1_1_2_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 3:
|
|
internal::gemm_q8_1_1_3_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 4:
|
|
internal::gemm_q8_1_1_4_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 5:
|
|
internal::gemm_q8_1_1_5_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 6:
|
|
internal::gemm_q8_1_1_6_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 7:
|
|
internal::gemm_q8_1_1_7_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
}
|
|
break;
|
|
case 2:
|
|
switch (k % 8) {
|
|
case 0:
|
|
internal::gemm_q8_1_2_0_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 1:
|
|
internal::gemm_q8_1_2_1_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 2:
|
|
internal::gemm_q8_1_2_2_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 3:
|
|
internal::gemm_q8_1_2_3_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 4:
|
|
internal::gemm_q8_1_2_4_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 5:
|
|
internal::gemm_q8_1_2_5_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 6:
|
|
internal::gemm_q8_1_2_6_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 7:
|
|
internal::gemm_q8_1_2_7_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
}
|
|
break;
|
|
}
|
|
break;
|
|
case 2:
|
|
switch (n % 3) {
|
|
case 0:
|
|
switch (k % 8) {
|
|
case 0:
|
|
internal::gemm_q8_2_0_0_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 1:
|
|
internal::gemm_q8_2_0_1_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 2:
|
|
internal::gemm_q8_2_0_2_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 3:
|
|
internal::gemm_q8_2_0_3_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 4:
|
|
internal::gemm_q8_2_0_4_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 5:
|
|
internal::gemm_q8_2_0_5_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 6:
|
|
internal::gemm_q8_2_0_6_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 7:
|
|
internal::gemm_q8_2_0_7_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
}
|
|
break;
|
|
case 1:
|
|
switch (k % 8) {
|
|
case 0:
|
|
internal::gemm_q8_2_1_0_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 1:
|
|
internal::gemm_q8_2_1_1_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 2:
|
|
internal::gemm_q8_2_1_2_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 3:
|
|
internal::gemm_q8_2_1_3_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 4:
|
|
internal::gemm_q8_2_1_4_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 5:
|
|
internal::gemm_q8_2_1_5_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 6:
|
|
internal::gemm_q8_2_1_6_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 7:
|
|
internal::gemm_q8_2_1_7_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
}
|
|
break;
|
|
case 2:
|
|
switch (k % 8) {
|
|
case 0:
|
|
internal::gemm_q8_2_2_0_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 1:
|
|
internal::gemm_q8_2_2_1_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 2:
|
|
internal::gemm_q8_2_2_2_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 3:
|
|
internal::gemm_q8_2_2_3_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 4:
|
|
internal::gemm_q8_2_2_4_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 5:
|
|
internal::gemm_q8_2_2_5_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 6:
|
|
internal::gemm_q8_2_2_6_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 7:
|
|
internal::gemm_q8_2_2_7_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_offset, multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
}
|
|
break;
|
|
}
|
|
break;
|
|
    }
  } else {
    switch (m % 3) {
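      // At least one alignment condition failed: use the kernels that make no
      // alignment assumptions, with the same m % 3 / n % 3 / k % 8 dispatch.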
case 0:
|
|
switch (n % 3) {
|
|
case 0:
|
|
switch (k % 8) {
|
|
case 0:
|
|
internal::gemm_q8_0_0_0(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 1:
|
|
internal::gemm_q8_0_0_1(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 2:
|
|
internal::gemm_q8_0_0_2(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 3:
|
|
internal::gemm_q8_0_0_3(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 4:
|
|
internal::gemm_q8_0_0_4(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 5:
|
|
internal::gemm_q8_0_0_5(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 6:
|
|
internal::gemm_q8_0_0_6(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 7:
|
|
internal::gemm_q8_0_0_7(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
}
|
|
break;
|
|
case 1:
|
|
switch (k % 8) {
|
|
case 0:
|
|
internal::gemm_q8_0_1_0(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 1:
|
|
internal::gemm_q8_0_1_1(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 2:
|
|
internal::gemm_q8_0_1_2(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 3:
|
|
internal::gemm_q8_0_1_3(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 4:
|
|
internal::gemm_q8_0_1_4(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 5:
|
|
internal::gemm_q8_0_1_5(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 6:
|
|
internal::gemm_q8_0_1_6(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 7:
|
|
internal::gemm_q8_0_1_7(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
}
|
|
break;
|
|
case 2:
|
|
switch (k % 8) {
|
|
case 0:
|
|
internal::gemm_q8_0_2_0(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 1:
|
|
internal::gemm_q8_0_2_1(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 2:
|
|
internal::gemm_q8_0_2_2(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 3:
|
|
internal::gemm_q8_0_2_3(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 4:
|
|
internal::gemm_q8_0_2_4(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 5:
|
|
internal::gemm_q8_0_2_5(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 6:
|
|
internal::gemm_q8_0_2_6(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 7:
|
|
internal::gemm_q8_0_2_7(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
}
|
|
break;
|
|
}
|
|
break;
|
|
case 1:
|
|
switch (n % 3) {
|
|
case 0:
|
|
switch (k % 8) {
|
|
case 0:
|
|
internal::gemm_q8_1_0_0(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 1:
|
|
internal::gemm_q8_1_0_1(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 2:
|
|
internal::gemm_q8_1_0_2(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 3:
|
|
internal::gemm_q8_1_0_3(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 4:
|
|
internal::gemm_q8_1_0_4(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 5:
|
|
internal::gemm_q8_1_0_5(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 6:
|
|
internal::gemm_q8_1_0_6(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 7:
|
|
internal::gemm_q8_1_0_7(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
}
|
|
break;
|
|
case 1:
|
|
switch (k % 8) {
|
|
case 0:
|
|
internal::gemm_q8_1_1_0(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 1:
|
|
internal::gemm_q8_1_1_1(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 2:
|
|
internal::gemm_q8_1_1_2(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 3:
|
|
internal::gemm_q8_1_1_3(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 4:
|
|
internal::gemm_q8_1_1_4(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 5:
|
|
internal::gemm_q8_1_1_5(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 6:
|
|
internal::gemm_q8_1_1_6(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 7:
|
|
internal::gemm_q8_1_1_7(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
}
|
|
break;
|
|
case 2:
|
|
switch (k % 8) {
|
|
case 0:
|
|
internal::gemm_q8_1_2_0(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 1:
|
|
internal::gemm_q8_1_2_1(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 2:
|
|
internal::gemm_q8_1_2_2(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 3:
|
|
internal::gemm_q8_1_2_3(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 4:
|
|
internal::gemm_q8_1_2_4(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 5:
|
|
internal::gemm_q8_1_2_5(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 6:
|
|
internal::gemm_q8_1_2_6(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 7:
|
|
internal::gemm_q8_1_2_7(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
}
|
|
break;
|
|
}
|
|
break;
|
|
case 2:
|
|
switch (n % 3) {
|
|
case 0:
|
|
switch (k % 8) {
|
|
case 0:
|
|
internal::gemm_q8_2_0_0(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 1:
|
|
internal::gemm_q8_2_0_1(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 2:
|
|
internal::gemm_q8_2_0_2(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 3:
|
|
internal::gemm_q8_2_0_3(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 4:
|
|
internal::gemm_q8_2_0_4(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 5:
|
|
internal::gemm_q8_2_0_5(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 6:
|
|
internal::gemm_q8_2_0_6(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 7:
|
|
internal::gemm_q8_2_0_7(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
}
|
|
break;
|
|
case 1:
|
|
switch (k % 8) {
|
|
case 0:
|
|
internal::gemm_q8_2_1_0(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 1:
|
|
internal::gemm_q8_2_1_1(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 2:
|
|
internal::gemm_q8_2_1_2(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 3:
|
|
internal::gemm_q8_2_1_3(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 4:
|
|
internal::gemm_q8_2_1_4(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 5:
|
|
internal::gemm_q8_2_1_5(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 6:
|
|
internal::gemm_q8_2_1_6(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 7:
|
|
internal::gemm_q8_2_1_7(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
}
|
|
break;
|
|
case 2:
|
|
switch (k % 8) {
|
|
case 0:
|
|
internal::gemm_q8_2_2_0(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 1:
|
|
internal::gemm_q8_2_2_1(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 2:
|
|
internal::gemm_q8_2_2_2(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 3:
|
|
internal::gemm_q8_2_2_3(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 4:
|
|
internal::gemm_q8_2_2_4(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 5:
|
|
internal::gemm_q8_2_2_5(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 6:
|
|
internal::gemm_q8_2_2_6(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
case 7:
|
|
internal::gemm_q8_2_2_7(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result_offset,
|
|
multiplicative_offset, shift, result,
|
|
result_stride);
|
|
break;
|
|
}
|
|
break;
|
|
}
|
|
break;
|
|
}
|
|
  }
}

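// Illustrative call of gemm_q8_strided (the scratch buffer sizing is the
// caller's responsibility and is not specified in this header):
//
//   std::vector<std::uint8_t> scratch(scratch_size);
//   gemmlowp::meta::gemm_q8_strided(scratch.data(), lhs, rhs, m, n, k,
//                                   lhs_offset, rhs_offset, result_offset,
//                                   multiplicative_offset, shift, result,
//                                   result_stride);

// gemm_i32_strided stores the raw 32-bit accumulators: lhs_offset and
// rhs_offset are still applied, but there is no result_offset /
// multiplicative_offset / shift requantization stage.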
void gemm_i32_strided(std::uint8_t* scratch, const std::uint8_t* lhs,
                      const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                      std::int32_t k, std::int32_t lhs_offset,
                      std::int32_t rhs_offset, std::int32_t* result,
                      std::int32_t result_stride) {
  const bool lhs_aligned = ((reinterpret_cast<std::uintptr_t>(lhs) % 8) == 0);
  const bool rhs_aligned = ((reinterpret_cast<std::uintptr_t>(rhs) % 8) == 0);
  const bool k_aligned = ((k % 8) == 0);
  const bool aligned = lhs_aligned && rhs_aligned && k_aligned;
  if (aligned) {
switch (m % 3) {
|
|
case 0:
|
|
switch (n % 3) {
|
|
case 0:
|
|
switch (k % 8) {
|
|
case 0:
|
|
internal::gemm_i32_0_0_0_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 1:
|
|
internal::gemm_i32_0_0_1_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 2:
|
|
internal::gemm_i32_0_0_2_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 3:
|
|
internal::gemm_i32_0_0_3_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 4:
|
|
internal::gemm_i32_0_0_4_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 5:
|
|
internal::gemm_i32_0_0_5_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 6:
|
|
internal::gemm_i32_0_0_6_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 7:
|
|
internal::gemm_i32_0_0_7_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
}
|
|
break;
|
|
case 1:
|
|
switch (k % 8) {
|
|
case 0:
|
|
internal::gemm_i32_0_1_0_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 1:
|
|
internal::gemm_i32_0_1_1_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 2:
|
|
internal::gemm_i32_0_1_2_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 3:
|
|
internal::gemm_i32_0_1_3_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 4:
|
|
internal::gemm_i32_0_1_4_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 5:
|
|
internal::gemm_i32_0_1_5_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 6:
|
|
internal::gemm_i32_0_1_6_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 7:
|
|
internal::gemm_i32_0_1_7_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
}
|
|
break;
|
|
case 2:
|
|
switch (k % 8) {
|
|
case 0:
|
|
internal::gemm_i32_0_2_0_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 1:
|
|
internal::gemm_i32_0_2_1_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 2:
|
|
internal::gemm_i32_0_2_2_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 3:
|
|
internal::gemm_i32_0_2_3_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 4:
|
|
internal::gemm_i32_0_2_4_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 5:
|
|
internal::gemm_i32_0_2_5_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 6:
|
|
internal::gemm_i32_0_2_6_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 7:
|
|
internal::gemm_i32_0_2_7_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
}
|
|
break;
|
|
}
|
|
break;
|
|
case 1:
|
|
switch (n % 3) {
|
|
case 0:
|
|
switch (k % 8) {
|
|
case 0:
|
|
internal::gemm_i32_1_0_0_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 1:
|
|
internal::gemm_i32_1_0_1_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 2:
|
|
internal::gemm_i32_1_0_2_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 3:
|
|
internal::gemm_i32_1_0_3_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 4:
|
|
internal::gemm_i32_1_0_4_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 5:
|
|
internal::gemm_i32_1_0_5_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 6:
|
|
internal::gemm_i32_1_0_6_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 7:
|
|
internal::gemm_i32_1_0_7_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
}
|
|
break;
|
|
case 1:
|
|
switch (k % 8) {
|
|
case 0:
|
|
internal::gemm_i32_1_1_0_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 1:
|
|
internal::gemm_i32_1_1_1_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 2:
|
|
internal::gemm_i32_1_1_2_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 3:
|
|
internal::gemm_i32_1_1_3_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 4:
|
|
internal::gemm_i32_1_1_4_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 5:
|
|
internal::gemm_i32_1_1_5_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 6:
|
|
internal::gemm_i32_1_1_6_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 7:
|
|
internal::gemm_i32_1_1_7_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
}
|
|
break;
|
|
case 2:
|
|
switch (k % 8) {
|
|
case 0:
|
|
internal::gemm_i32_1_2_0_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 1:
|
|
internal::gemm_i32_1_2_1_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 2:
|
|
internal::gemm_i32_1_2_2_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 3:
|
|
internal::gemm_i32_1_2_3_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 4:
|
|
internal::gemm_i32_1_2_4_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 5:
|
|
internal::gemm_i32_1_2_5_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 6:
|
|
internal::gemm_i32_1_2_6_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 7:
|
|
internal::gemm_i32_1_2_7_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
}
|
|
break;
|
|
}
|
|
break;
|
|
case 2:
|
|
switch (n % 3) {
|
|
case 0:
|
|
switch (k % 8) {
|
|
case 0:
|
|
internal::gemm_i32_2_0_0_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 1:
|
|
internal::gemm_i32_2_0_1_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 2:
|
|
internal::gemm_i32_2_0_2_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 3:
|
|
internal::gemm_i32_2_0_3_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 4:
|
|
internal::gemm_i32_2_0_4_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 5:
|
|
internal::gemm_i32_2_0_5_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 6:
|
|
internal::gemm_i32_2_0_6_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 7:
|
|
internal::gemm_i32_2_0_7_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
}
|
|
break;
|
|
case 1:
|
|
switch (k % 8) {
|
|
case 0:
|
|
internal::gemm_i32_2_1_0_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 1:
|
|
internal::gemm_i32_2_1_1_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 2:
|
|
internal::gemm_i32_2_1_2_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 3:
|
|
internal::gemm_i32_2_1_3_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 4:
|
|
internal::gemm_i32_2_1_4_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 5:
|
|
internal::gemm_i32_2_1_5_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 6:
|
|
internal::gemm_i32_2_1_6_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 7:
|
|
internal::gemm_i32_2_1_7_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
}
|
|
break;
|
|
case 2:
|
|
switch (k % 8) {
|
|
case 0:
|
|
internal::gemm_i32_2_2_0_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 1:
|
|
internal::gemm_i32_2_2_1_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 2:
|
|
internal::gemm_i32_2_2_2_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 3:
|
|
internal::gemm_i32_2_2_3_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 4:
|
|
internal::gemm_i32_2_2_4_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 5:
|
|
internal::gemm_i32_2_2_5_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 6:
|
|
internal::gemm_i32_2_2_6_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
case 7:
|
|
internal::gemm_i32_2_2_7_aligned(scratch, lhs, rhs, m, n, k,
|
|
lhs_offset, rhs_offset, result,
|
|
result_stride);
|
|
break;
|
|
}
|
|
break;
|
|
}
|
|
break;
|
|
    }
  } else {
    switch (m % 3) {
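      // Unaligned fallback: same m % 3 / n % 3 / k % 8 dispatch, kernels
      // without alignment assumptions.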
case 0:
|
|
switch (n % 3) {
|
|
case 0:
|
|
switch (k % 8) {
|
|
case 0:
|
|
internal::gemm_i32_0_0_0(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 1:
|
|
internal::gemm_i32_0_0_1(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 2:
|
|
internal::gemm_i32_0_0_2(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 3:
|
|
internal::gemm_i32_0_0_3(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 4:
|
|
internal::gemm_i32_0_0_4(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 5:
|
|
internal::gemm_i32_0_0_5(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 6:
|
|
internal::gemm_i32_0_0_6(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 7:
|
|
internal::gemm_i32_0_0_7(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
}
|
|
break;
|
|
case 1:
|
|
switch (k % 8) {
|
|
case 0:
|
|
internal::gemm_i32_0_1_0(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 1:
|
|
internal::gemm_i32_0_1_1(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 2:
|
|
internal::gemm_i32_0_1_2(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 3:
|
|
internal::gemm_i32_0_1_3(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 4:
|
|
internal::gemm_i32_0_1_4(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 5:
|
|
internal::gemm_i32_0_1_5(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 6:
|
|
internal::gemm_i32_0_1_6(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 7:
|
|
internal::gemm_i32_0_1_7(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
}
|
|
break;
|
|
case 2:
|
|
switch (k % 8) {
|
|
case 0:
|
|
internal::gemm_i32_0_2_0(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 1:
|
|
internal::gemm_i32_0_2_1(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 2:
|
|
internal::gemm_i32_0_2_2(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 3:
|
|
internal::gemm_i32_0_2_3(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 4:
|
|
internal::gemm_i32_0_2_4(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 5:
|
|
internal::gemm_i32_0_2_5(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 6:
|
|
internal::gemm_i32_0_2_6(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 7:
|
|
internal::gemm_i32_0_2_7(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
}
|
|
break;
|
|
}
|
|
break;
|
|
case 1:
|
|
switch (n % 3) {
|
|
case 0:
|
|
switch (k % 8) {
|
|
case 0:
|
|
internal::gemm_i32_1_0_0(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 1:
|
|
internal::gemm_i32_1_0_1(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 2:
|
|
internal::gemm_i32_1_0_2(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 3:
|
|
internal::gemm_i32_1_0_3(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 4:
|
|
internal::gemm_i32_1_0_4(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 5:
|
|
internal::gemm_i32_1_0_5(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 6:
|
|
internal::gemm_i32_1_0_6(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 7:
|
|
internal::gemm_i32_1_0_7(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
}
|
|
break;
|
|
case 1:
|
|
switch (k % 8) {
|
|
case 0:
|
|
internal::gemm_i32_1_1_0(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 1:
|
|
internal::gemm_i32_1_1_1(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 2:
|
|
internal::gemm_i32_1_1_2(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 3:
|
|
internal::gemm_i32_1_1_3(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 4:
|
|
internal::gemm_i32_1_1_4(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 5:
|
|
internal::gemm_i32_1_1_5(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 6:
|
|
internal::gemm_i32_1_1_6(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 7:
|
|
internal::gemm_i32_1_1_7(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
}
|
|
break;
|
|
case 2:
|
|
switch (k % 8) {
|
|
case 0:
|
|
internal::gemm_i32_1_2_0(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 1:
|
|
internal::gemm_i32_1_2_1(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 2:
|
|
internal::gemm_i32_1_2_2(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 3:
|
|
internal::gemm_i32_1_2_3(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 4:
|
|
internal::gemm_i32_1_2_4(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 5:
|
|
internal::gemm_i32_1_2_5(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 6:
|
|
internal::gemm_i32_1_2_6(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 7:
|
|
internal::gemm_i32_1_2_7(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
}
|
|
break;
|
|
}
|
|
break;
|
|
case 2:
|
|
switch (n % 3) {
|
|
case 0:
|
|
switch (k % 8) {
|
|
case 0:
|
|
internal::gemm_i32_2_0_0(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 1:
|
|
internal::gemm_i32_2_0_1(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 2:
|
|
internal::gemm_i32_2_0_2(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 3:
|
|
internal::gemm_i32_2_0_3(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 4:
|
|
internal::gemm_i32_2_0_4(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 5:
|
|
internal::gemm_i32_2_0_5(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 6:
|
|
internal::gemm_i32_2_0_6(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 7:
|
|
internal::gemm_i32_2_0_7(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
}
|
|
break;
|
|
case 1:
|
|
switch (k % 8) {
|
|
case 0:
|
|
internal::gemm_i32_2_1_0(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 1:
|
|
internal::gemm_i32_2_1_1(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 2:
|
|
internal::gemm_i32_2_1_2(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 3:
|
|
internal::gemm_i32_2_1_3(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 4:
|
|
internal::gemm_i32_2_1_4(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 5:
|
|
internal::gemm_i32_2_1_5(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 6:
|
|
internal::gemm_i32_2_1_6(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 7:
|
|
internal::gemm_i32_2_1_7(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
}
|
|
break;
|
|
case 2:
|
|
switch (k % 8) {
|
|
case 0:
|
|
internal::gemm_i32_2_2_0(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 1:
|
|
internal::gemm_i32_2_2_1(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 2:
|
|
internal::gemm_i32_2_2_2(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 3:
|
|
internal::gemm_i32_2_2_3(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 4:
|
|
internal::gemm_i32_2_2_4(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 5:
|
|
internal::gemm_i32_2_2_5(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 6:
|
|
internal::gemm_i32_2_2_6(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
case 7:
|
|
internal::gemm_i32_2_2_7(scratch, lhs, rhs, m, n, k, lhs_offset,
|
|
rhs_offset, result, result_stride);
|
|
break;
|
|
}
|
|
break;
|
|
}
|
|
break;
|
|
}
|
|
  }
}

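// gemm_f_strided is the float-output variant: the 32-bit accumulators are
// scaled by result_scale and written to the float result buffer.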
void gemm_f_strided(std::uint8_t* scratch, const std::uint8_t* lhs,
                    const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
                    std::int32_t k, std::int32_t lhs_offset,
                    std::int32_t rhs_offset, float result_scale, float* result,
                    std::int32_t result_stride) {
  const bool lhs_aligned = ((reinterpret_cast<std::uintptr_t>(lhs) % 8) == 0);
  const bool rhs_aligned = ((reinterpret_cast<std::uintptr_t>(rhs) % 8) == 0);
  const bool k_aligned = ((k % 8) == 0);
  const bool aligned = lhs_aligned && rhs_aligned && k_aligned;
  if (aligned) {
switch (m % 3) {
|
|
case 0:
|
|
switch (n % 3) {
|
|
case 0:
|
|
switch (k % 8) {
|
|
case 0:
|
|
internal::gemm_f_0_0_0_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_scale, result, result_stride);
|
|
break;
|
|
case 1:
|
|
internal::gemm_f_0_0_1_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_scale, result, result_stride);
|
|
break;
|
|
case 2:
|
|
internal::gemm_f_0_0_2_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_scale, result, result_stride);
|
|
break;
|
|
case 3:
|
|
internal::gemm_f_0_0_3_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_scale, result, result_stride);
|
|
break;
|
|
case 4:
|
|
internal::gemm_f_0_0_4_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_scale, result, result_stride);
|
|
break;
|
|
case 5:
|
|
internal::gemm_f_0_0_5_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_scale, result, result_stride);
|
|
break;
|
|
case 6:
|
|
internal::gemm_f_0_0_6_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_scale, result, result_stride);
|
|
break;
|
|
case 7:
|
|
internal::gemm_f_0_0_7_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_scale, result, result_stride);
|
|
break;
|
|
}
|
|
break;
|
|
case 1:
|
|
switch (k % 8) {
|
|
case 0:
|
|
internal::gemm_f_0_1_0_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_scale, result, result_stride);
|
|
break;
|
|
case 1:
|
|
internal::gemm_f_0_1_1_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_scale, result, result_stride);
|
|
break;
|
|
case 2:
|
|
internal::gemm_f_0_1_2_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_scale, result, result_stride);
|
|
break;
|
|
case 3:
|
|
internal::gemm_f_0_1_3_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_scale, result, result_stride);
|
|
break;
|
|
case 4:
|
|
internal::gemm_f_0_1_4_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_scale, result, result_stride);
|
|
break;
|
|
case 5:
|
|
internal::gemm_f_0_1_5_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_scale, result, result_stride);
|
|
break;
|
|
case 6:
|
|
internal::gemm_f_0_1_6_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_scale, result, result_stride);
|
|
break;
|
|
case 7:
|
|
internal::gemm_f_0_1_7_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_scale, result, result_stride);
|
|
break;
|
|
}
|
|
break;
|
|
case 2:
|
|
switch (k % 8) {
|
|
case 0:
|
|
internal::gemm_f_0_2_0_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_scale, result, result_stride);
|
|
break;
|
|
case 1:
|
|
internal::gemm_f_0_2_1_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_scale, result, result_stride);
|
|
break;
|
|
case 2:
|
|
internal::gemm_f_0_2_2_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_scale, result, result_stride);
|
|
break;
|
|
case 3:
|
|
internal::gemm_f_0_2_3_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_scale, result, result_stride);
|
|
break;
|
|
case 4:
|
|
internal::gemm_f_0_2_4_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_scale, result, result_stride);
|
|
break;
|
|
case 5:
|
|
internal::gemm_f_0_2_5_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_scale, result, result_stride);
|
|
break;
|
|
case 6:
|
|
internal::gemm_f_0_2_6_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_scale, result, result_stride);
|
|
break;
|
|
case 7:
|
|
internal::gemm_f_0_2_7_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_scale, result, result_stride);
|
|
break;
|
|
}
|
|
break;
|
|
}
|
|
break;
|
|
case 1:
|
|
switch (n % 3) {
|
|
case 0:
|
|
switch (k % 8) {
|
|
case 0:
|
|
internal::gemm_f_1_0_0_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_scale, result, result_stride);
|
|
break;
|
|
case 1:
|
|
internal::gemm_f_1_0_1_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_scale, result, result_stride);
|
|
break;
|
|
case 2:
|
|
internal::gemm_f_1_0_2_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_scale, result, result_stride);
|
|
break;
|
|
case 3:
|
|
internal::gemm_f_1_0_3_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_scale, result, result_stride);
|
|
break;
|
|
case 4:
|
|
internal::gemm_f_1_0_4_aligned(
|
|
scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
|
|
result_scale, result, result_stride);
|
|
break;
|
|
case 5:
|
|
                internal::gemm_f_1_0_5_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
              case 6:
                internal::gemm_f_1_0_6_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
              case 7:
                internal::gemm_f_1_0_7_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
            }
            break;
          case 1:
            switch (k % 8) {
              case 0:
                internal::gemm_f_1_1_0_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
              case 1:
                internal::gemm_f_1_1_1_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
              case 2:
                internal::gemm_f_1_1_2_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
              case 3:
                internal::gemm_f_1_1_3_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
              case 4:
                internal::gemm_f_1_1_4_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
              case 5:
                internal::gemm_f_1_1_5_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
              case 6:
                internal::gemm_f_1_1_6_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
              case 7:
                internal::gemm_f_1_1_7_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
            }
            break;
          case 2:
            switch (k % 8) {
              case 0:
                internal::gemm_f_1_2_0_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
              case 1:
                internal::gemm_f_1_2_1_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
              case 2:
                internal::gemm_f_1_2_2_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
              case 3:
                internal::gemm_f_1_2_3_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
              case 4:
                internal::gemm_f_1_2_4_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
              case 5:
                internal::gemm_f_1_2_5_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
              case 6:
                internal::gemm_f_1_2_6_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
              case 7:
                internal::gemm_f_1_2_7_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
            }
            break;
        }
        break;
      case 2:
        switch (n % 3) {
          case 0:
            switch (k % 8) {
              case 0:
                internal::gemm_f_2_0_0_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
              case 1:
                internal::gemm_f_2_0_1_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
              case 2:
                internal::gemm_f_2_0_2_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
              case 3:
                internal::gemm_f_2_0_3_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
              case 4:
                internal::gemm_f_2_0_4_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
              case 5:
                internal::gemm_f_2_0_5_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
              case 6:
                internal::gemm_f_2_0_6_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
              case 7:
                internal::gemm_f_2_0_7_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
            }
            break;
          case 1:
            switch (k % 8) {
              case 0:
                internal::gemm_f_2_1_0_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
              case 1:
                internal::gemm_f_2_1_1_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
              case 2:
                internal::gemm_f_2_1_2_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
              case 3:
                internal::gemm_f_2_1_3_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
              case 4:
                internal::gemm_f_2_1_4_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
              case 5:
                internal::gemm_f_2_1_5_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
              case 6:
                internal::gemm_f_2_1_6_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
              case 7:
                internal::gemm_f_2_1_7_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
            }
            break;
          case 2:
            switch (k % 8) {
              case 0:
                internal::gemm_f_2_2_0_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
              case 1:
                internal::gemm_f_2_2_1_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
              case 2:
                internal::gemm_f_2_2_2_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
              case 3:
                internal::gemm_f_2_2_3_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
              case 4:
                internal::gemm_f_2_2_4_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
              case 5:
                internal::gemm_f_2_2_5_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
              case 6:
                internal::gemm_f_2_2_6_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
              case 7:
                internal::gemm_f_2_2_7_aligned(
                    scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    result_scale, result, result_stride);
                break;
            }
            break;
        }
        break;
    }
  } else {
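    // The preceding alignment check did not hold, so fall back to the kernel
    // variants without the _aligned suffix.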
    switch (m % 3) {
      case 0:
        switch (n % 3) {
          case 0:
            switch (k % 8) {
              case 0:
                internal::gemm_f_0_0_0(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 1:
                internal::gemm_f_0_0_1(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 2:
                internal::gemm_f_0_0_2(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 3:
                internal::gemm_f_0_0_3(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 4:
                internal::gemm_f_0_0_4(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 5:
                internal::gemm_f_0_0_5(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 6:
                internal::gemm_f_0_0_6(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 7:
                internal::gemm_f_0_0_7(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
            }
            break;
          case 1:
            switch (k % 8) {
              case 0:
                internal::gemm_f_0_1_0(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 1:
                internal::gemm_f_0_1_1(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 2:
                internal::gemm_f_0_1_2(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 3:
                internal::gemm_f_0_1_3(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 4:
                internal::gemm_f_0_1_4(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 5:
                internal::gemm_f_0_1_5(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 6:
                internal::gemm_f_0_1_6(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 7:
                internal::gemm_f_0_1_7(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
            }
            break;
          case 2:
            switch (k % 8) {
              case 0:
                internal::gemm_f_0_2_0(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 1:
                internal::gemm_f_0_2_1(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 2:
                internal::gemm_f_0_2_2(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 3:
                internal::gemm_f_0_2_3(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 4:
                internal::gemm_f_0_2_4(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 5:
                internal::gemm_f_0_2_5(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 6:
                internal::gemm_f_0_2_6(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 7:
                internal::gemm_f_0_2_7(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
            }
            break;
        }
        break;
      case 1:
        switch (n % 3) {
          case 0:
            switch (k % 8) {
              case 0:
                internal::gemm_f_1_0_0(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 1:
                internal::gemm_f_1_0_1(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 2:
                internal::gemm_f_1_0_2(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 3:
                internal::gemm_f_1_0_3(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 4:
                internal::gemm_f_1_0_4(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 5:
                internal::gemm_f_1_0_5(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 6:
                internal::gemm_f_1_0_6(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 7:
                internal::gemm_f_1_0_7(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
            }
            break;
          case 1:
            switch (k % 8) {
              case 0:
                internal::gemm_f_1_1_0(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 1:
                internal::gemm_f_1_1_1(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 2:
                internal::gemm_f_1_1_2(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 3:
                internal::gemm_f_1_1_3(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 4:
                internal::gemm_f_1_1_4(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 5:
                internal::gemm_f_1_1_5(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 6:
                internal::gemm_f_1_1_6(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 7:
                internal::gemm_f_1_1_7(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
            }
            break;
          case 2:
            switch (k % 8) {
              case 0:
                internal::gemm_f_1_2_0(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 1:
                internal::gemm_f_1_2_1(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 2:
                internal::gemm_f_1_2_2(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 3:
                internal::gemm_f_1_2_3(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 4:
                internal::gemm_f_1_2_4(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 5:
                internal::gemm_f_1_2_5(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 6:
                internal::gemm_f_1_2_6(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 7:
                internal::gemm_f_1_2_7(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
            }
            break;
        }
        break;
      case 2:
        switch (n % 3) {
          case 0:
            switch (k % 8) {
              case 0:
                internal::gemm_f_2_0_0(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 1:
                internal::gemm_f_2_0_1(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 2:
                internal::gemm_f_2_0_2(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 3:
                internal::gemm_f_2_0_3(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 4:
                internal::gemm_f_2_0_4(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 5:
                internal::gemm_f_2_0_5(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 6:
                internal::gemm_f_2_0_6(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 7:
                internal::gemm_f_2_0_7(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
            }
            break;
          case 1:
            switch (k % 8) {
              case 0:
                internal::gemm_f_2_1_0(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 1:
                internal::gemm_f_2_1_1(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 2:
                internal::gemm_f_2_1_2(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 3:
                internal::gemm_f_2_1_3(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 4:
                internal::gemm_f_2_1_4(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 5:
                internal::gemm_f_2_1_5(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 6:
                internal::gemm_f_2_1_6(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 7:
                internal::gemm_f_2_1_7(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
            }
            break;
          case 2:
            switch (k % 8) {
              case 0:
                internal::gemm_f_2_2_0(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 1:
                internal::gemm_f_2_2_1(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 2:
                internal::gemm_f_2_2_2(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 3:
                internal::gemm_f_2_2_3(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 4:
                internal::gemm_f_2_2_4(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 5:
                internal::gemm_f_2_2_5(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 6:
                internal::gemm_f_2_2_6(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
              case 7:
                internal::gemm_f_2_2_7(scratch, lhs, rhs, m, n, k, lhs_offset,
                                       rhs_offset, result_scale, result,
                                       result_stride);
                break;
            }
            break;
        }
        break;
    }
  }
}

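// Note on the dispatch above: the generated kernels are specialized on the
// leftover sizes of each dimension, so gemm_f_<m % 3>_<n % 3>_<k % 8> handles
// the corresponding remainder combination, and the *_aligned variants are
// selected only when the preceding check on the operands (an alignment test
// performed earlier in the function) succeeds.
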
void gemm_q8(std::uint8_t* scratch, const std::uint8_t* lhs,
             const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
             std::int32_t k, std::int32_t lhs_offset, std::int32_t rhs_offset,
             std::int32_t result_offset, std::int32_t multiplicative_offset,
             std::int32_t shift, std::uint8_t* result) {
  gemm_q8_strided(scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                  result_offset, multiplicative_offset, shift, result, n);
}

void gemm_i32(std::uint8_t* scratch, const std::uint8_t* lhs,
              const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
              std::int32_t k, std::int32_t lhs_offset, std::int32_t rhs_offset,
              std::int32_t* result) {
  gemm_i32_strided(scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset, result,
                   n);
}

void gemm_f(std::uint8_t* scratch, const std::uint8_t* lhs,
            const std::uint8_t* rhs, std::int32_t m, std::int32_t n,
            std::int32_t k, std::int32_t lhs_offset, std::int32_t rhs_offset,
            float result_scale, float* result) {
  gemm_f_strided(scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                 result_scale, result, n);
}
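
// Illustrative usage sketch (not part of the generated API): the wrappers
// above simply forward to the *_strided entry points with result_stride == n.
// A hypothetical call to the float-output path could look like the snippet
// below; kScratchSize stands in for the scratch-buffer requirement, which is
// defined by the packing code elsewhere in this library and is only assumed
// here.
//
//   std::vector<std::uint8_t> scratch(kScratchSize);
//   std::vector<std::uint8_t> lhs(m * k), rhs(k * n);  // quantized operands
//   std::vector<float> result(m * n);
//   gemmlowp::meta::gemm_f(scratch.data(), lhs.data(), rhs.data(), m, n, k,
//                          lhs_offset, rhs_offset, result_scale,
//                          result.data());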

}  // namespace meta
}  // namespace gemmlowp

#else
#warning "Meta gemm fast-path requires GEMMLOWP_NEON_32!"
#endif

#endif  // GEMMLOWP_META_SINGLE_THREAD_GEMM_H_