diff --git a/dlk/python/dlk/templates/Makefile.tpl b/dlk/python/dlk/templates/Makefile.tpl
index 16ac549f5..a4831a17e 100644
--- a/dlk/python/dlk/templates/Makefile.tpl
+++ b/dlk/python/dlk/templates/Makefile.tpl
@@ -72,6 +72,15 @@ LIB_X86_SRC := \
     $(SRC_DIR)/func/impl/generic/pack_16bit.cpp
 LIB_X86_OBJ := $(patsubst %.cpp, %.o, $(LIB_X86_SRC))
 
+LIB_X86_AVX_SRC := \
+    $(SRC_DIR)/func/generic/batch_normalization.cpp \
+    $(SRC_DIR)/func/impl/x86_avx/quantized_conv2d_tiling.cpp \
+    $(SRC_DIR)/matrix/generic/quantized_multiplication.cpp \
+    $(SRC_DIR)/func/impl/generic/pop_count.cpp \
+    $(SRC_DIR)/func/impl/x86_avx/apply_thresholds.cpp \
+    $(SRC_DIR)/func/impl/x86_avx/pack_16bit.cpp
+LIB_X86_AVX_OBJ := $(patsubst %.cpp, %.o, $(LIB_X86_AVX_SRC))
+
 LIB_OBJ := $(patsubst %.cpp, %.o, $(LIB_SRC))
 OBJ := $(patsubst %.cpp, %.o, $(SRC))
@@ -81,6 +90,8 @@ HLS_INCLUDE := -I./hls/include
 
 TARGETS_X86 := lm_x86
 
+TARGETS_X86_AVX := lm_x86_avx
+
 TARGETS_AARCH64 := lm_aarch64
 
 TARGETS_ARM := lm_arm
@@ -89,6 +100,8 @@ TARGETS_FPGA := lm_fpga
 
 LIBS_X86 := lib_x86
 
+LIBS_X86_AVX := lib_x86_avx
+
 LIBS_AARCH64 := lib_aarch64
 
 LIBS_ARM := lib_arm
@@ -97,6 +110,8 @@ LIBS_FPGA := lib_fpga
 
 ARS_X86 := ar_x86
 
+ARS_X86_AVX := ar_x86_avx
+
 ARS_AARCH64 := ar_aarch64
 
 ARS_X86 := ar_x86
@@ -136,6 +151,10 @@ lm_x86: CXX = g++
 lm_x86: FLAGS += $(INCLUDES) -O3 -std=c++14 -DUSE_PNG -pthread -g
 lm_x86: CXXFLAGS +=
 
+lm_x86_avx: CXX = g++
+lm_x86_avx: FLAGS += $(INCLUDES) -O3 -std=c++14 -mavx2 -DUSE_AVX -DUSE_PNG -pthread -g -fopenmp
+lm_x86_avx: CXXFLAGS +=
+
 lm_aarch64: CXX = aarch64-linux-gnu-g++
 lm_aarch64: FLAGS += $(INCLUDES) -std=c++14 -O3 -DUSE_NEON -DUSE_PNG -pthread -g -fopenmp
 lm_aarch64: CXXFLAGS +=
@@ -152,6 +171,10 @@ lib_x86: CXX = g++
 lib_x86: FLAGS += $(INCLUDES) -O3 -std=c++14 -fPIC -fvisibility=hidden -pthread -g
 lib_x86: CXXFLAGS +=
 
+lib_x86_avx: CXX = g++
+lib_x86_avx: FLAGS += $(INCLUDES) -O3 -std=c++14 -fPIC -fvisibility=hidden -DUSE_AVX -pthread -g -fopenmp
+lib_x86_avx: CXXFLAGS +=
+
 lib_aarch64: CXX = aarch64-linux-gnu-g++
 lib_aarch64: FLAGS += $(INCLUDES) -O3 -std=c++14 -fPIC -fvisibility=hidden -DUSE_NEON -pthread -g
 lib_aarch64: CXXFLAGS +=
@@ -170,6 +193,12 @@ ar_x86: FLAGS += $(INCLUDES) -O3 -std=c++14 -fPIC -fvisibility=hidden
 ar_x86: LDFLAGS += -rcs
 ar_x86: NAME = x86
 
+ar_x86_avx: AR = ar
+ar_x86_avx: CXX = g++
+ar_x86_avx: FLAGS += $(INCLUDES) -O3 -std=c++14 -fPIC -fvisibility=hidden -DUSE_AVX -pthread -g -fopenmp
+ar_x86_avx: LDFLAGS += -rcs
+ar_x86_avx: NAME = x86_avx
+
 ar_aarch64: AR = aarch64-linux-gnu-ar
 ar_aarch64: CXX = aarch64-linux-gnu-g++
 ar_aarch64: FLAGS += $(INCLUDES) -O3 -std=c++14 -fPIC -fvisibility=hidden -DUSE_NEON -pthread -g
@@ -201,9 +230,15 @@ $(TARGETS_AARCH64): $(OBJ) $(TVM_OBJ) $(LIB_AARCH64_OBJ)
 
 $(TARGETS_X86): $(OBJ) $(TVM_OBJ) $(LIB_X86_OBJ)
 	$(CXX) $(FLAGS) $(OBJ) $(TVM_OBJ) $(LIB_X86_OBJ) -o $@.elf $(CXXFLAGS) $(TVM_X86_LIBS) -pthread -ldl
 
+$(TARGETS_X86_AVX): $(OBJ) $(TVM_OBJ) $(LIB_X86_AVX_OBJ)
+	$(CXX) $(FLAGS) $(OBJ) $(TVM_OBJ) $(LIB_X86_AVX_OBJ) -o $@.elf $(CXXFLAGS) $(TVM_X86_AVX_LIBS) -pthread -ldl
+
 $(LIBS_X86): $(LIB_OBJ) $(TVM_OBJ) $(LIB_X86_OBJ)
 	$(CXX) $(FLAGS) $(LIB_OBJ) $(TVM_OBJ) $(LIB_X86_OBJ) -o $@.so $(CXXFLAGS) $(TVM_X86_LIBS) -shared -pthread -ldl
 
+$(LIBS_X86_AVX): $(LIB_OBJ) $(TVM_OBJ) $(LIB_X86_AVX_OBJ)
+	$(CXX) $(FLAGS) $(LIB_OBJ) $(TVM_OBJ) $(LIB_X86_AVX_OBJ) -o $@.so $(CXXFLAGS) $(TVM_X86_AVX_LIBS) -shared -pthread -ldl
+
 $(LIBS_AARCH64): $(LIB_OBJ) $(TVM_OBJ) $(LIB_AARCH64_OBJ)
 	$(CXX) $(FLAGS) $(LIB_OBJ) $(TVM_OBJ) $(LIB_AARCH64_OBJ) -o $@.so $(CXXFLAGS) $(TVM_AARCH64_LIBS) -shared -pthread -ldl
@@ -216,6 +251,9 @@ $(LIBS_FPGA): $(LIB_OBJ) $(TVM_OBJ) $(LIB_FPGA_OBJ)
 
 $(ARS_X86): $(LIB_OBJ) $(TVM_OBJ) $(LIB_X86_OBJ)
 	$(AR) $(LDFLAGS) libdlk_$(NAME).a $(LIB_OBJ) $(TVM_OBJ) $(TVM_X86_LIBS) $(LIB_X86_OBJ)
 
+$(ARS_X86_AVX): $(LIB_OBJ) $(TVM_OBJ) $(LIB_X86_AVX_OBJ)
+	$(AR) $(LDFLAGS) libdlk_$(NAME).a $(LIB_OBJ) $(TVM_OBJ) $(TVM_X86_AVX_LIBS) $(LIB_X86_AVX_OBJ)
+
 $(ARS_AARCH64): $(LIB_OBJ) $(TVM_OBJ) $(LIB_AARCH64_OBJ)
 	$(AR) $(LDFLAGS) libdlk_$(NAME).a $(LIB_OBJ) $(TVM_OBJ) $(TVM_AARCH64_LIBS) $(LIB_AARCH64_OBJ)
diff --git a/dlk/python/dlk/templates/include/func/quantized_conv2d.h b/dlk/python/dlk/templates/include/func/quantized_conv2d.h
index d351397ac..50731eb10 100644
--- a/dlk/python/dlk/templates/include/func/quantized_conv2d.h
+++ b/dlk/python/dlk/templates/include/func/quantized_conv2d.h
@@ -60,7 +60,7 @@ void QuantizedConv2D(const TensorView& input,
   convert_tensor(input, tmp);
   Measurement::Stop();
   dlk::impl::TCAConv2d(tmp, kernel, p);
-#elif defined USE_NEON
+#elif defined USE_NEON || defined USE_AVX
   dlk::impl::tiling_input_t::tensor_info_t shape = {
     ic / TilingInTypeBitWidth,
     ih,
diff --git a/dlk/python/dlk/templates/include/tensor_view.h b/dlk/python/dlk/templates/include/tensor_view.h
index 28601be3b..223b4ae59 100644
--- a/dlk/python/dlk/templates/include/tensor_view.h
+++ b/dlk/python/dlk/templates/include/tensor_view.h
@@ -245,7 +245,7 @@ class TensorView, memory_layout> {
 
 #ifdef RUN_ON_FPGA
   using kernel_t = TensorView;
-#elif defined USE_NEON
+#elif defined USE_NEON || defined USE_AVX
   using kernel_t = TensorView;
 #else
   using kernel_t = TensorView;
diff --git a/dlk/python/dlk/templates/manual/consts/input.tpl.cpp b/dlk/python/dlk/templates/manual/consts/input.tpl.cpp
index e1631cf2b..b577d4494 100644
--- a/dlk/python/dlk/templates/manual/consts/input.tpl.cpp
+++ b/dlk/python/dlk/templates/manual/consts/input.tpl.cpp
@@ -42,7 +42,7 @@ static constexpr decltype({{ node.name }})::tensor_info_t {{ node.n
 
 const TensorView<{{ node.dtype.cpptype() }}, MemoryLayout::{{ node.transposed_dimension_format }}> {{ node.name }}(
     reinterpret_cast<{{ node.dtype.cpptype() }}*>({{ node.name }}_raw), {{ node.name }}_shape);
-#elif defined USE_NEON
+#elif defined USE_NEON || defined USE_AVX
 static Base<{{ node.dtype.cpptype() }}>::type {{ node.name }}_raw[] = {
 {% for d in node.data.flatten() -%}
 {{- d -}},
diff --git a/dlk/python/dlk/templates/manual/consts/input.tpl.h b/dlk/python/dlk/templates/manual/consts/input.tpl.h
index 7a022a13e..054a13e32 100644
--- a/dlk/python/dlk/templates/manual/consts/input.tpl.h
+++ b/dlk/python/dlk/templates/manual/consts/input.tpl.h
@@ -26,7 +26,7 @@ extern const TensorView<{{ node.dtype.cpptype() }}, MemoryLayout::Atom> {{ node.
 
 #ifdef RUN_ON_FPGA
 extern const TensorView<{{ node.dtype.cpptype() }}, MemoryLayout::{{ node.transposed_dimension_format }}> {{ node.name }};
-#elif defined USE_NEON
+#elif defined USE_NEON || defined USE_AVX
 extern const TensorView<{{ node.dtype.cpptype() }}, MemoryLayout::{{ node.dimension}}> {{ node.name }};
 #else
 extern const TensorView<{{ node.dtype.cpptype() }}, MemoryLayout::{{ node.kn2row_dimension_format }}> {{ node.name }};
diff --git a/dlk/python/dlk/templates/src/func/impl/x86_avx/apply_thresholds.cpp b/dlk/python/dlk/templates/src/func/impl/x86_avx/apply_thresholds.cpp
new file mode 100644
index 000000000..6b8122038
--- /dev/null
+++ b/dlk/python/dlk/templates/src/func/impl/x86_avx/apply_thresholds.cpp
@@ -0,0 +1,139 @@
+/* Copyright 2018 The Blueoil Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "global.h"
+#include "matrix_view.h"
+#include "operators.h" // FIXME(nikolay): for binary_convolution_parameters definition, rid of it later
+#include "time_measurement.h"
+
+#include <x86intrin.h>
+
+namespace dlk {
+
+namespace impl {
+
+void ApplyThresholds(
+    dlk::MatrixView<BIN_CONV_OUTPUT, dlk::MatrixOrder::ColMajor> &result,
+    const binary_convolution_parameters &p) {
+  Measurement::Start("ApplyThresholds");
+
+  const auto buf_ts0 = std::make_unique<BIN_CONV_OUTPUT[]>(result.rows());
+  const auto buf_ts1 = std::make_unique<BIN_CONV_OUTPUT[]>(result.rows());
+  const auto buf_ts2 = std::make_unique<BIN_CONV_OUTPUT[]>(result.rows());
+  const auto buf_flg = std::make_unique<BIN_CONV_OUTPUT[]>(result.rows());
+
+  for (unsigned int i = 0; i < result.rows(); ++i) {
+    T_INT ts0 = p.thresholds[NUM_OF_A2W1_THRESHOLD * i];
+    T_INT ts1 = p.thresholds[NUM_OF_A2W1_THRESHOLD * i + 1];
+    T_INT ts2 = p.thresholds[NUM_OF_A2W1_THRESHOLD * i + 2];
+    T_INT flag = p.thresholds[NUM_OF_A2W1_THRESHOLD * i + 3];
+    if (flag == -1) {
+      ++ts0;
+      ++ts1;
+      ++ts2;
+    }
+    buf_ts0[i] = ts0;
+    buf_ts1[i] = ts1;
+    buf_ts2[i] = ts2;
+    buf_flg[i] = flag;
+  }
+
+#pragma omp parallel for
+  for (unsigned int j = 0; j < result.cols(); ++j) {
+    for (unsigned int i = 0; i < result.rows(); i += 16) {
+      const auto d = _mm256_loadu_si256(reinterpret_cast<__m256i*>(result.data(i, j)));
+      const auto ts0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(buf_ts0.get() + i));
+      const auto ts1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(buf_ts1.get() + i));
+      const auto ts2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(buf_ts2.get() + i));
+      const auto flg = _mm256_loadu_si256(reinterpret_cast<__m256i*>(buf_flg.get() + i));
+      const auto f0 = _mm256_andnot_si256(_mm256_cmpgt_epi16(ts0, d), flg);
+      const auto f1 = _mm256_andnot_si256(_mm256_cmpgt_epi16(ts1, d), flg);
+      const auto f2 = _mm256_andnot_si256(_mm256_cmpgt_epi16(ts2, d), flg);
+      const auto is_neg = _mm256_cmpgt_epi16(_mm256_setzero_si256(), flg);
+      const auto tmp = _mm256_add_epi16(_mm256_add_epi16(f0, f1), _mm256_add_epi16(f2, is_neg));
+      const auto m2 = _mm256_sub_epi16(flg, _mm256_set1_epi16(2));
+      const auto is_not_const = _mm256_cmpgt_epi16(_mm256_setzero_si256(), m2);
+      const auto res = _mm256_blendv_epi8(m2, tmp, is_not_const);
+      _mm256_storeu_si256(reinterpret_cast<__m256i*>(result.data(i, j)), res);
+    }
+  }
+
+  Measurement::Stop();
+}
+
+void ApplyThresholdsAndPack(
+    dlk::MatrixView<BIN_CONV_OUTPUT, dlk::MatrixOrder::ColMajor> &result,
+    const binary_convolution_parameters &p,
+    QUANTIZED_PACKED output[]) {
+  Measurement::Start("ApplyThresholdsAndPack");
+
+  const auto buf_ts0 = std::make_unique<BIN_CONV_OUTPUT[]>(result.rows());
+  const auto buf_ts1 = std::make_unique<BIN_CONV_OUTPUT[]>(result.rows());
+  const auto buf_ts2 = std::make_unique<BIN_CONV_OUTPUT[]>(result.rows());
+  const auto buf_flg = std::make_unique<BIN_CONV_OUTPUT[]>(result.rows());
+
+  for (unsigned int i = 0; i < result.rows(); ++i) {
+    T_INT ts0 = p.thresholds[NUM_OF_A2W1_THRESHOLD * i];
+    T_INT ts1 = p.thresholds[NUM_OF_A2W1_THRESHOLD * i + 1];
+    T_INT ts2 = p.thresholds[NUM_OF_A2W1_THRESHOLD * i + 2];
+    T_INT flag = p.thresholds[NUM_OF_A2W1_THRESHOLD * i + 3];
+    if (flag == -1) {
+      ++ts0;
+      ++ts1;
+      ++ts2;
+    }
+    buf_ts0[i] = ts0;
+    buf_ts1[i] = ts1;
+    buf_ts2[i] = ts2;
+    buf_flg[i] = flag;
+  }
+
+#pragma omp parallel for
+  for (unsigned int j = 0; j < result.cols(); ++j) {
+    for (unsigned int i = 0; i < result.rows(); i += 32) {
+#define APPLY(k) \
+  const auto d##k = _mm256_loadu_si256(reinterpret_cast<__m256i*>(result.data(i + k * 16, j))); \
+  const auto ts0##k = _mm256_loadu_si256(reinterpret_cast<__m256i*>(buf_ts0.get() + i + k * 16)); \
+  const auto ts1##k = _mm256_loadu_si256(reinterpret_cast<__m256i*>(buf_ts1.get() + i + k * 16)); \
+  const auto ts2##k = _mm256_loadu_si256(reinterpret_cast<__m256i*>(buf_ts2.get() + i + k * 16)); \
+  const auto flg##k = _mm256_loadu_si256(reinterpret_cast<__m256i*>(buf_flg.get() + i + k * 16)); \
+  const auto f0##k = _mm256_andnot_si256(_mm256_cmpgt_epi16(ts0##k, d##k), flg##k); \
+  const auto f1##k = _mm256_andnot_si256(_mm256_cmpgt_epi16(ts1##k, d##k), flg##k); \
+  const auto f2##k = _mm256_andnot_si256(_mm256_cmpgt_epi16(ts2##k, d##k), flg##k); \
+  const auto is_neg##k = _mm256_cmpgt_epi16(_mm256_setzero_si256(), flg##k); \
+  const auto tmp##k = _mm256_add_epi16(_mm256_add_epi16(f0##k, f1##k), _mm256_add_epi16(f2##k, is_neg##k)); \
+  const auto m2##k = _mm256_sub_epi16(flg##k, _mm256_set1_epi16(2)); \
+  const auto is_not_const##k = _mm256_cmpgt_epi16(_mm256_setzero_si256(), m2##k); \
+  const auto res##k = _mm256_blendv_epi8(m2##k, tmp##k, is_not_const##k);
+      APPLY(0)
+      APPLY(1)
+      const auto packed = _mm256_packs_epi16(res0, res1);
+      const auto permuted = _mm256_permute4x64_epi64(packed, 0xD8);
+      const auto vlsb = _mm256_slli_epi32(permuted, 7);
+      const auto vmsb = _mm256_slli_epi32(permuted, 6);
+      const auto lsb = _mm256_movemask_epi8(vlsb);
+      const auto msb = _mm256_movemask_epi8(vmsb);
+      const auto index = (j + (i / 32) * result.cols()) * 2;
+      output[index] = QUANTIZED_PACKED(lsb);
+      output[index+1] = QUANTIZED_PACKED(msb);
+    }
+  }
+
+  Measurement::Stop();
+}
+
+} // namespace impl
+
+} // namespace dlk
diff --git a/dlk/python/dlk/templates/src/func/impl/x86_avx/pack_16bit.cpp b/dlk/python/dlk/templates/src/func/impl/x86_avx/pack_16bit.cpp
new file mode 100644
index 000000000..20c5f0f53
--- /dev/null
+++ b/dlk/python/dlk/templates/src/func/impl/x86_avx/pack_16bit.cpp
@@ -0,0 +1,51 @@
+/* Copyright 2019 The Blueoil Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "func/impl/pack_16bit.h"
+#include <cassert>
+
+#include <x86intrin.h>
+
+#include "time_measurement.h"
+
+namespace dlk {
+
+namespace impl {
+
+void pack_16bit(const BIN_CONV_OUTPUT input[], QUANTIZED_PACKED output[], const std::size_t length) {
+  using base = QUANTIZED_PACKED::base_t;
+  const auto bits = QUANTIZED_PACKED::BitCount;
+  assert((length % bits) == 0);
+  Measurement::Start("pack bits");
+  std::size_t j = 0;
+  for (std::size_t i = 0; i < length; i += bits) {
+    const auto v1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(input + i));
+    const auto v2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(input + i + 16));
+    const auto packed = _mm256_packs_epi16(v1, v2);
+    const auto permuted = _mm256_permute4x64_epi64(packed, 0xD8);
+    const auto vlsb = _mm256_slli_epi32(permuted, 7);
+    const auto vmsb = _mm256_slli_epi32(permuted, 6);
+    const auto lsb = _mm256_movemask_epi8(vlsb);
+    const auto msb = _mm256_movemask_epi8(vmsb);
+    output[j] = QUANTIZED_PACKED(lsb);
+    output[j+1] = QUANTIZED_PACKED(msb);
+    j += 2;
+  }
+  Measurement::Stop();
+}
+
+} // namespace impl
+
+} // namespace dlk
diff --git a/dlk/python/dlk/templates/src/func/impl/x86_avx/quantized_conv2d_tiling.cpp b/dlk/python/dlk/templates/src/func/impl/x86_avx/quantized_conv2d_tiling.cpp
new file mode 100644
index 000000000..b543785f3
--- /dev/null
+++ b/dlk/python/dlk/templates/src/func/impl/x86_avx/quantized_conv2d_tiling.cpp
@@ -0,0 +1,297 @@
+/* Copyright 2018 The Blueoil Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cassert>
+#include <climits>
+
+#include "global.h"
+#include "func/impl/apply_thresholds.h"
+#include "func/impl/quantized_conv2d_tiling.h"
+#include "func/impl/pack_16bit.h"
+#include "time_measurement.h"
+#include "tensor_convert.h"
+
+#include <x86intrin.h>
+
+#ifdef _OPENMP
+#include <omp.h>
+#endif
+
+namespace dlk {
+
+namespace impl {
+
+static const auto buf_th = std::make_unique<QUANTIZED_PACKED[]>(MAX_SIZE_QOUTPUTS_PER_LAYER);
+static const auto buf_non_th = std::make_unique<BIN_CONV_OUTPUT[]>(MAX_SIZE_OUTPUTS_PER_LAYER);
+
+void pack_input_for_tiling(const TensorView<QUANTIZED_NOT_PACKED, MemoryLayout::NHWC>& input,
+                           const tiling_input_t& output) {
+  Measurement::Start("Pack_input_for_tiling");
+  const T_UINT in_channels = input.get_shape()[3];
+  const T_UINT in_height = input.get_shape()[1];
+  const T_UINT in_width = input.get_shape()[2];
+  const T_UINT in_bitwidth = output.get_shape()[3];
+
+  constexpr T_UINT InTypeBitWidth = CHAR_BIT * sizeof(uint32_t);
+  const T_UINT in_stride = (in_channels + InTypeBitWidth - 1) / InTypeBitWidth;
+#pragma omp parallel for schedule(dynamic)
+  for (unsigned int in_ch_high = 0; in_ch_high < in_stride; ++in_ch_high) {
+    for (unsigned int row = 0; row < in_height; ++row) {
+      for (unsigned int col = 0; col < in_width; ++col) {
+        for (unsigned int in_bit_ch = 0; in_bit_ch < in_bitwidth; ++in_bit_ch) {
+          output(in_ch_high, row, col, in_bit_ch, 0) = tiling_input_elem_t(0);
+        }
+      }
+    }
+  }
+#pragma omp parallel for schedule(dynamic)
+  for (unsigned int row = 0; row < in_height; ++row) {
+    for (unsigned int col = 0; col < in_width; ++col) {
+      for (unsigned int in_ch_high = 0; in_ch_high < in_channels; in_ch_high += InTypeBitWidth) {
+        for (unsigned int in_ch_low = 0; in_ch_low < InTypeBitWidth; ++in_ch_low) {
+          unsigned int in_ch = in_ch_high + in_ch_low;
+          if (in_ch >= in_channels) break;
+          QUANTIZED_NOT_PACKED val = input(0, row, col, in_ch);
+          for (unsigned int in_bit_ch = 0; in_bit_ch < in_bitwidth; ++in_bit_ch) {
+            tiling_input_elem_base_t bit = (val >> in_bit_ch) & 1;
+            output(in_ch_high / InTypeBitWidth, row, col, in_bit_ch, 0) |= tiling_input_elem_t(bit << in_ch_low);
+          }
+        }
+      }
+    }
+  }
+
+  Measurement::Stop();
+}
+
+void QuantizedConv2DTiling(const tiling_input_t& input,
+                           const kernel_t& kernel,
+                           const binary_convolution_parameters &p) {
+  constexpr T_UINT InTypeBitWidth = tiling_input_elem_t::BitCount;
+  convolution_parameters cp = p.normal_conv_params;
+  const T_UINT out_channels = cp.output_channels;
+  const T_UINT kh = cp.kernel_height;
+  const T_UINT kw = cp.kernel_width;
+  const T_UINT in_bitwidth = 2;
+  const T_UINT in_channels = cp.kernel_depth;
+  const T_UINT in_height = cp.input_height;
+  const T_UINT in_width = cp.input_width;
+  const T_UINT in_stride = (in_channels + InTypeBitWidth - 1) / InTypeBitWidth;
+  const T_UINT padding = cp.padding;
+  const T_UINT out_height = cp.output_height;
+  const T_UINT out_width = cp.output_width;
+  const T_UINT out_size = out_height * out_width * out_channels;
+
+  //assert(kh * kw < 32);
+  assert(in_height * in_width == out_height * out_width);
+  assert((in_channels % InTypeBitWidth) == 0);
+
+  const T_UINT TileHeight = std::min(in_height, T_UINT(32)); // configurable
+  const T_UINT TileWidth = std::min(in_width + (in_width & 1), T_UINT(32)); // configurable
+  constexpr T_UINT InChUnroll = InTypeBitWidth; // hardcoded, not configurable
+  constexpr T_UINT OutChUnroll = 16; // hardcoded, not configurable
+  constexpr T_UINT InBitChUnroll = 2; // hardcoded, not configurable
+  constexpr T_UINT ColUnroll = 2; // hardcoded, not configurable
+
+  const T_UINT row_tile_count = (in_height + TileHeight - 1) / TileHeight;
+  const T_UINT col_tile_count = (in_width + TileWidth - 1) / TileWidth;
+  const T_UINT out_tile_count = (out_channels + OutChUnroll - 1) / OutChUnroll;
+  const T_UINT total_tile_count = row_tile_count * col_tile_count * out_tile_count;
+  Measurement::Start("Quantized Conv2D Tiling");
+  const auto mask4 = _mm256_set1_epi8(0x0F);
+  const auto popc_table = _mm256_setr_epi8(
+    0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4,
+    0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4
+  );
+  const auto vone = _mm256_set1_epi8(0x01);
+  const auto vone_16 = _mm256_set1_epi16(0x0001);
+#pragma omp parallel for schedule(guided)
+  for (T_UINT tile_index = 0; tile_index < total_tile_count; ++tile_index) {
+    T_UINT out_ch_high = tile_index % out_tile_count * OutChUnroll;
+    T_UINT col_high = (tile_index / out_tile_count) % col_tile_count * TileWidth;
+    T_UINT row_high = tile_index / (out_tile_count * col_tile_count) * TileHeight;
+    int16_t out_tile[TileHeight][TileWidth][OutChUnroll];
+    for (unsigned int row = 0; row < TileHeight; ++row) {
+      for (unsigned int col = 0; col < TileWidth; ++col) {
+        for (unsigned int out_ch = 0; out_ch < OutChUnroll; ++out_ch) {
+          out_tile[row][col][out_ch] = 0;
+        }
+      }
+    }
+    for (unsigned int in_ch_high = 0; in_ch_high < in_channels; in_ch_high += InTypeBitWidth) {
+      QUANTIZED_PACKED_KERNEL notk[kh][kw][OutChUnroll];
+      int16_t notsum[OutChUnroll] = {};
+      for (unsigned int out_ch = 0; out_ch < OutChUnroll; ++out_ch) {
+        notsum[out_ch] = 0;
+        for (unsigned int kr = 0; kr < kh; ++kr) {
+          for (unsigned int kc = 0; kc < kw; ++kc) {
+            const auto index = (out_ch_high + out_ch) * kh * kw * (in_channels / InTypeBitWidth)
+              + kr * kw * (in_channels / InTypeBitWidth)
+              + kc * (in_channels / InTypeBitWidth)
+              + (in_ch_high / InTypeBitWidth);
+            notk[kr][kc][out_ch] = kernel.data()[index];
+            notsum[out_ch] += pop_count(notk[kr][kc][out_ch]);
+          }
+        }
+      }
+      for (unsigned int in_bit_ch_high = 0; in_bit_ch_high < in_bitwidth; in_bit_ch_high += InBitChUnroll) {
+        tiling_input_elem_t in_tile[TileHeight + kh - 1][TileWidth + kw - 1][InBitChUnroll];
+        for (unsigned int row = 0; row < TileHeight + kh - 1; ++row) {
+          if (row_high + row >= in_height + 2*padding) break;
+          for (unsigned int col = 0; col < TileWidth + kw - 1; ++col) {
+            if (col_high + col >= in_width + 2*padding) break;
+            for (unsigned int in_bit_ch = 0; in_bit_ch < InBitChUnroll; ++in_bit_ch) {
+              if (row_high + row < padding || row_high + row >= in_height + padding
+                  || col_high + col < padding || col_high + col >= in_width + padding) {
+                in_tile[row][col][in_bit_ch] = tiling_input_elem_t(0);
+              } else {
+                const auto index = (in_ch_high / InTypeBitWidth) * in_height * in_width * in_bitwidth
+                  + (row_high + row - padding) * in_width * in_bitwidth
+                  + (col_high + col - padding) * in_bitwidth
+                  + (in_bit_ch_high + in_bit_ch);
+                in_tile[row][col][in_bit_ch] = input.data()[index];
+              }
+            }
+          }
+        }
+        for (unsigned int row = 0; row < TileHeight; ++row) {
+          for (unsigned int col = 0; col < TileWidth; col += ColUnroll) {
+            auto xnorsum000 = _mm256_setzero_si256();
+            auto xnorsum001 = _mm256_setzero_si256();
+            auto xnorsum010 = _mm256_setzero_si256();
+            auto xnorsum011 = _mm256_setzero_si256();
+            auto xnorsum100 = _mm256_setzero_si256();
+            auto xnorsum101 = _mm256_setzero_si256();
+            auto xnorsum110 = _mm256_setzero_si256();
+            auto xnorsum111 = _mm256_setzero_si256();
+            for (unsigned int kr = 0; kr < kh; ++kr) {
+              auto in00 = _mm256_set1_epi32(in_tile[row + kr][col][0].Raw());
+              auto in10 = _mm256_set1_epi32(in_tile[row + kr][col][1].Raw());
+              for (unsigned int kc = 0; kc < kw; ++kc) {
+                const auto nk0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(&notk[kr][kc][ 0]));
+                const auto nk1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(&notk[kr][kc][ 8]));
+                const auto in01 = _mm256_set1_epi32(in_tile[row + kr][col + kc + 1][0].Raw());
+#define BINDP(i, j, k) \
+  do { \
+    const auto xnor = in##i##j ^ nk##k; \
+    const auto l4 = mask4 & xnor; \
+    const auto popc_l4 = _mm256_shuffle_epi8(popc_table, l4); \
+    const auto h4 = mask4 & _mm256_srli_epi32(xnor, 4); \
+    const auto popc_h4 = _mm256_shuffle_epi8(popc_table, h4); \
+    const auto cnt = _mm256_add_epi8(popc_l4, popc_h4); \
+    xnorsum##i##j##k = _mm256_add_epi8(xnorsum##i##j##k, cnt); \
+  } while (0);
+                BINDP(0, 0, 0);
+                BINDP(0, 0, 1);
+                BINDP(0, 1, 0);
+                BINDP(0, 1, 1);
+                in00 = in01;
+                const auto in11 = _mm256_set1_epi32(in_tile[row + kr][col + kc + 1][1].Raw());
+                BINDP(1, 0, 0);
+                BINDP(1, 0, 1);
+                BINDP(1, 1, 0);
+                BINDP(1, 1, 1);
+                in10 = in11;
+              }
+            }
+            const auto psum0000 = _mm256_maddubs_epi16(xnorsum000, vone);
+            const auto psum0001 = _mm256_maddubs_epi16(xnorsum001, vone);
+            const auto psum0010 = _mm256_maddubs_epi16(xnorsum010, vone);
+            const auto psum0011 = _mm256_maddubs_epi16(xnorsum011, vone);
+            const auto psum0100 = _mm256_maddubs_epi16(xnorsum100, vone);
+            const auto psum0101 = _mm256_maddubs_epi16(xnorsum101, vone);
+            const auto psum0110 = _mm256_maddubs_epi16(xnorsum110, vone);
+            const auto psum0111 = _mm256_maddubs_epi16(xnorsum111, vone);
+            const auto psum1000 = _mm256_madd_epi16(psum0000, vone_16);
+            const auto psum1001 = _mm256_madd_epi16(psum0001, vone_16);
+            const auto psum1010 = _mm256_madd_epi16(psum0010, vone_16);
+            const auto psum1011 = _mm256_madd_epi16(psum0011, vone_16);
+            const auto psum1100 = _mm256_madd_epi16(psum0100, vone_16);
+            const auto psum1101 = _mm256_madd_epi16(psum0101, vone_16);
+            const auto psum1110 = _mm256_madd_epi16(psum0110, vone_16);
+            const auto psum1111 = _mm256_madd_epi16(psum0111, vone_16);
+            const auto usum000 = _mm256_packs_epi32(psum1000, psum1001);
+            const auto usum001 = _mm256_packs_epi32(psum1010, psum1011);
+            const auto usum010 = _mm256_packs_epi32(psum1100, psum1101);
+            const auto usum011 = _mm256_packs_epi32(psum1110, psum1111);
+            const auto usum100 = _mm256_permute4x64_epi64(usum000, 0xD8);
+            const auto usum101 = _mm256_permute4x64_epi64(usum001, 0xD8);
+            const auto usum110 = _mm256_permute4x64_epi64(usum010, 0xD8);
+            const auto usum111 = _mm256_permute4x64_epi64(usum011, 0xD8);
+            const auto tmp0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(&out_tile[row][col + 0][0]));
+            const auto tmp1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(&out_tile[row][col + 1][0]));
+            const auto nsum = _mm256_loadu_si256(reinterpret_cast<__m256i*>(&notsum[0]));
+            const auto diff00 = _mm256_sub_epi16(usum100, nsum);
+            const auto diff01 = _mm256_sub_epi16(usum101, nsum);
+            const auto diff10 = _mm256_sub_epi16(usum110, nsum);
+            const auto diff11 = _mm256_sub_epi16(usum111, nsum);
+            const auto shifted00 = _mm256_slli_epi16(diff00, in_bit_ch_high);
+            const auto shifted01 = _mm256_slli_epi16(diff01, in_bit_ch_high);
+            const auto shifted10 = _mm256_slli_epi16(diff10, in_bit_ch_high + 1);
+            const auto shifted11 = _mm256_slli_epi16(diff11, in_bit_ch_high + 1);
+            const auto res0 = _mm256_add_epi16(tmp0, _mm256_add_epi16(shifted00, shifted10));
+            const auto res1 = _mm256_add_epi16(tmp1, _mm256_add_epi16(shifted01, shifted11));
+            _mm256_storeu_si256(reinterpret_cast<__m256i*>(&out_tile[row][col + 0][0]), res0);
+            _mm256_storeu_si256(reinterpret_cast<__m256i*>(&out_tile[row][col + 1][0]), res1);
+          }
+        }
+      }
+    }
+    for (unsigned int row = 0; row < TileHeight; ++row) {
+      if (row_high + row >= out_height) break;
+      for (unsigned int col = 0; col < TileWidth; ++col) {
+        if (col_high + col >= out_width) break;
+        for (unsigned int out_ch = 0; out_ch < OutChUnroll; ++out_ch) {
+          unsigned int index = (row_high + row) * out_width * out_channels
+            + (col_high + col) * out_channels
+            + (out_ch_high + out_ch);
+          p.device_output_buf[index] = out_tile[row][col][out_ch];
+        }
+      }
+    }
+  }
+  Measurement::Stop();
+
+  using namespace dlk;
+  auto output_ = MatrixView<BIN_CONV_OUTPUT, MatrixOrder::ColMajor>(
+      p.device_output_buf, out_channels, in_height * in_width);
+
+  if (p.thresholds != nullptr) {
+    ApplyThresholdsAndPack(output_, p, buf_th.get());
+    Measurement::Start("copy");
+    std::copy(buf_th.get(), buf_th.get() + out_size / 32 * 2, (QUANTIZED_PACKED*)p.device_output_buf);
+    Measurement::Stop();
+  } else {
+    const std::size_t b = 32;
+    Measurement::Start("copy");
+    std::copy(p.device_output_buf, p.device_output_buf + out_size, buf_non_th.get());
+    Measurement::Stop();
+    TensorView<BIN_CONV_OUTPUT, MemoryLayout::HWC>::tensor_info_t buf_shape = {
+      out_height, out_width, out_channels
+    };
+    TensorView<BIN_CONV_OUTPUT, MemoryLayout::HWC> buf_tensor(buf_non_th.get(), buf_shape);
+    TensorView<QUANTIZED_PACKED, MemoryLayout::ChHWCl>::tensor_info_t out_shape = {
+      (out_channels + b - 1) / b, out_height, out_width, b
+    };
+    TensorView<QUANTIZED_PACKED, MemoryLayout::ChHWCl> out(p.device_output_buf, out_shape);
+    Measurement::Start("Output tensor convert");
+    convert_tensor(buf_tensor, out);
+    Measurement::Stop();
+  }
+}
+
+} // namespace impl
+
+} // namespace dlk