This repository has been archived by the owner on Dec 1, 2021. It is now read-only.

Commit

Optimize for x86 AVX (#365)
* Optimize for x86 AVX

* Optimize apply_thresholds for x86 AVX

* Optimize pack_16bit

* Use guided (see the OpenMP scheduling sketch after this list)

* Optimize apply_thresholds, pack_16bit, convert_tensor for AVX
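The "Use guided" item presumably refers to OpenMP's guided loop scheduling; the file it touches (quantized_conv2d_tiling.cpp) is not expanded in this diff, so this is an assumption. As an illustration only, a minimal self-contained sketch of the pragma (compile with -fopenmp):

#include <cstdio>
#include <omp.h>

int main() {
  // schedule(guided): threads grab successively smaller iteration chunks,
  // which load-balances loops whose iterations have uneven cost.
  #pragma omp parallel for schedule(guided)
  for (int i = 0; i < 1024; ++i) {
    volatile int work = i * i;  // stand-in for real per-iteration work
    (void)work;
  }
  std::printf("ran with up to %d threads\n", omp_get_max_threads());
}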
primenumber authored and iizukak committed Jul 18, 2019
1 parent 12a763d commit 5abf157
Showing 8 changed files with 529 additions and 4 deletions.
38 changes: 38 additions & 0 deletions dlk/python/dlk/templates/Makefile.tpl
@@ -72,6 +72,15 @@ LIB_X86_SRC := \
$(SRC_DIR)/func/impl/generic/pack_16bit.cpp
LIB_X86_OBJ := $(patsubst %.cpp, %.o, $(LIB_X86_SRC))

LIB_X86_AVX_SRC := \
$(SRC_DIR)/func/generic/batch_normalization.cpp \
$(SRC_DIR)/func/impl/x86_avx/quantized_conv2d_tiling.cpp \
$(SRC_DIR)/matrix/generic/quantized_multiplication.cpp \
$(SRC_DIR)/func/impl/generic/pop_count.cpp \
$(SRC_DIR)/func/impl/x86_avx/apply_thresholds.cpp \
$(SRC_DIR)/func/impl/x86_avx/pack_16bit.cpp
LIB_X86_AVX_OBJ := $(patsubst %.cpp, %.o, $(LIB_X86_AVX_SRC))

LIB_OBJ := $(patsubst %.cpp, %.o, $(LIB_SRC))
OBJ := $(patsubst %.cpp, %.o, $(SRC))

@@ -81,6 +90,8 @@ HLS_INCLUDE := -I./hls/include

TARGETS_X86 := lm_x86

TARGETS_X86_AVX := lm_x86_avx

TARGETS_AARCH64 := lm_aarch64

TARGETS_ARM := lm_arm
@@ -89,6 +100,8 @@ TARGETS_FPGA := lm_fpga

LIBS_X86 := lib_x86

LIBS_X86_AVX := lib_x86_avx

LIBS_AARCH64 := lib_aarch64

LIBS_ARM := lib_arm
@@ -97,6 +110,8 @@ LIBS_FPGA := lib_fpga

ARS_X86 := ar_x86

ARS_X86_AVX := ar_x86_avx

ARS_AARCH64 := ar_aarch64

ARS_ARM := ar_arm
@@ -136,6 +151,10 @@ lm_x86: CXX = g++
lm_x86: FLAGS += $(INCLUDES) -O3 -std=c++14 -DUSE_PNG -pthread -g
lm_x86: CXXFLAGS +=

lm_x86_avx: CXX = g++
lm_x86_avx: FLAGS += $(INCLUDES) -O3 -std=c++14 -mavx2 -DUSE_AVX -DUSE_PNG -pthread -g -fopenmp
lm_x86_avx: CXXFLAGS +=

lm_aarch64: CXX = aarch64-linux-gnu-g++
lm_aarch64: FLAGS += $(INCLUDES) -std=c++14 -O3 -DUSE_NEON -DUSE_PNG -pthread -g -fopenmp
lm_aarch64: CXXFLAGS +=
@@ -152,6 +171,10 @@ lib_x86: CXX = g++
lib_x86: FLAGS += $(INCLUDES) -O3 -std=c++14 -fPIC -fvisibility=hidden -pthread -g
lib_x86: CXXFLAGS +=

lib_x86_avx: CXX = g++
lib_x86_avx: FLAGS += $(INCLUDES) -O3 -std=c++14 -fPIC -fvisibility=hidden -DUSE_AVX -pthread -g -fopenmp
lib_x86_avx: CXXFLAGS +=

lib_aarch64: CXX = aarch64-linux-gnu-g++
lib_aarch64: FLAGS += $(INCLUDES) -O3 -std=c++14 -fPIC -fvisibility=hidden -DUSE_NEON -pthread -g
lib_aarch64: CXXFLAGS +=
@@ -170,6 +193,12 @@ ar_x86: FLAGS += $(INCLUDES) -O3 -std=c++14 -fPIC -fvisibility=hidden
ar_x86: LDFLAGS += -rcs
ar_x86: NAME = x86

ar_x86_avx: AR = ar
ar_x86_avx: CXX = g++
ar_x86_avx: FLAGS += $(INCLUDES) -O3 -std=c++14 -fPIC -fvisibility=hidden -DUSE_AVX -pthread -g -fopenmp
ar_x86_avx: LDFLAGS += -rcs
ar_x86_avx: NAME = x86_avx

ar_aarch64: AR = aarch64-linux-gnu-ar
ar_aarch64: CXX = aarch64-linux-gnu-g++
ar_aarch64: FLAGS += $(INCLUDES) -O3 -std=c++14 -fPIC -fvisibility=hidden -DUSE_NEON -pthread -g
@@ -201,9 +230,15 @@ $(TARGETS_AARCH64): $(OBJ) $(TVM_OBJ) $(LIB_AARCH64_OBJ)
$(TARGETS_X86): $(OBJ) $(TVM_OBJ) $(LIB_X86_OBJ)
$(CXX) $(FLAGS) $(OBJ) $(TVM_OBJ) $(LIB_X86_OBJ) -o $@.elf $(CXXFLAGS) $(TVM_X86_LIBS) -pthread -ldl

$(TARGETS_X86_AVX): $(OBJ) $(TVM_OBJ) $(LIB_X86_AVX_OBJ)
$(CXX) $(FLAGS) $(OBJ) $(TVM_OBJ) $(LIB_X86_AVX_OBJ) -o $@.elf $(CXXFLAGS) $(TVM_X86_AVX_LIBS) -pthread -ldl

$(LIBS_X86): $(LIB_OBJ) $(TVM_OBJ) $(LIB_X86_OBJ)
$(CXX) $(FLAGS) $(LIB_OBJ) $(TVM_OBJ) $(LIB_X86_OBJ) -o $@.so $(CXXFLAGS) $(TVM_X86_LIBS) -shared -pthread -ldl

$(LIBS_X86_AVX): $(LIB_OBJ) $(TVM_OBJ) $(LIB_X86_AVX_OBJ)
$(CXX) $(FLAGS) $(LIB_OBJ) $(TVM_OBJ) $(LIB_X86_AVX_OBJ) -o $@.so $(CXXFLAGS) $(TVM_X86_AVX_LIBS) -shared -pthread -ldl

$(LIBS_AARCH64): $(LIB_OBJ) $(TVM_OBJ) $(LIB_AARCH64_OBJ)
$(CXX) $(FLAGS) $(LIB_OBJ) $(TVM_OBJ) $(LIB_AARCH64_OBJ) -o $@.so $(CXXFLAGS) $(TVM_AARCH64_LIBS) -shared -pthread -ldl

@@ -216,6 +251,9 @@ $(LIBS_FPGA): $(LIB_OBJ) $(TVM_OBJ) $(LIB_FPGA_OBJ)
$(ARS_X86): $(LIB_OBJ) $(TVM_OBJ) $(LIB_X86_OBJ)
$(AR) $(LDFLAGS) libdlk_$(NAME).a $(LIB_OBJ) $(TVM_OBJ) $(TVM_X86_LIBS) $(LIB_X86_OBJ)

$(ARS_X86_AVX): $(LIB_OBJ) $(TVM_OBJ) $(LIB_X86_AVX_OBJ)
$(AR) $(LDFLAGS) libdlk_$(NAME).a $(LIB_OBJ) $(TVM_OBJ) $(TVM_X86_AVX_LIBS) $(LIB_X86_AVX_OBJ)

$(ARS_AARCH64): $(LIB_OBJ) $(TVM_OBJ) $(LIB_AARCH64_OBJ)
$(AR) $(LDFLAGS) libdlk_$(NAME).a $(LIB_OBJ) $(TVM_OBJ) $(TVM_AARCH64_LIBS) $(LIB_AARCH64_OBJ)

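The new lm_x86_avx / lib_x86_avx / ar_x86_avx rules mirror the plain x86 variants but add -DUSE_AVX and -fopenmp (and, for lm_x86_avx, -mavx2); the template sources below then branch on USE_AVX to pick the SIMD data layouts. A minimal, hypothetical illustration of that compile-time dispatch (not from the commit):

#include <cstdio>

int main() {
// The same source compiles to different paths depending on which macro
// the chosen Makefile variant defines.
#if defined USE_AVX
  std::puts("AVX build: x86_avx implementations");
#elif defined USE_NEON
  std::puts("NEON build: ARM implementations");
#else
  std::puts("generic build");
#endif
}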
2 changes: 1 addition & 1 deletion dlk/python/dlk/templates/include/func/quantized_conv2d.h
@@ -60,7 +60,7 @@ void QuantizedConv2D(const TensorView<T, layout>& input,
convert_tensor(input, tmp);
Measurement::Stop();
dlk::impl::TCAConv2d(tmp, kernel, p);
#elif defined USE_NEON
#elif defined USE_NEON || defined USE_AVX
dlk::impl::tiling_input_t::tensor_info_t<std::size_t> shape = {
ic / TilingInTypeBitWidth,
ih,
2 changes: 1 addition & 1 deletion dlk/python/dlk/templates/include/tensor_view.h
@@ -245,7 +245,7 @@ class TensorView<QuantizedPacked<T>, memory_layout> {

#ifdef RUN_ON_FPGA
using kernel_t = TensorView<QUANTIZED_PACKED_KERNEL, MemoryLayout::OhIhHWOlIl>;
#elif defined USE_NEON
#elif defined USE_NEON || defined USE_AVX
using kernel_t = TensorView<QUANTIZED_PACKED_KERNEL, MemoryLayout::NHWC>;
#else
using kernel_t = TensorView<QUANTIZED_PACKED_KERNEL, MemoryLayout::HWNC>;
2 changes: 1 addition & 1 deletion dlk/python/dlk/templates/manual/consts/input.tpl.cpp
@@ -42,7 +42,7 @@ static constexpr decltype({{ node.name }})::tensor_info_t<std::size_t> {{ node.n
const TensorView<{{ node.dtype.cpptype() }}, MemoryLayout::{{ node.transposed_dimension_format }}> {{ node.name }}(
reinterpret_cast<{{ node.dtype.cpptype() }}*>({{ node.name }}_raw),
{{ node.name }}_shape);
#elif defined USE_NEON
#elif defined USE_NEON || defined USE_AVX
static Base<{{ node.dtype.cpptype() }}>::type {{ node.name }}_raw[] = {
{% for d in node.data.flatten() -%}
{{- d -}},
2 changes: 1 addition & 1 deletion dlk/python/dlk/templates/manual/consts/input.tpl.h
@@ -26,7 +26,7 @@ extern const TensorView<{{ node.dtype.cpptype() }}, MemoryLayout::Atom> {{ node.

#ifdef RUN_ON_FPGA
extern const TensorView<{{ node.dtype.cpptype() }}, MemoryLayout::{{ node.transposed_dimension_format }}> {{ node.name }};
#elif defined USE_NEON
#elif defined USE_NEON || defined USE_AVX
extern const TensorView<{{ node.dtype.cpptype() }}, MemoryLayout::{{ node.dimension}}> {{ node.name }};
#else
extern const TensorView<{{ node.dtype.cpptype() }}, MemoryLayout::{{ node.kn2row_dimension_format }}> {{ node.name }};
139 changes: 139 additions & 0 deletions dlk/python/dlk/templates/src/func/impl/x86_avx/apply_thresholds.cpp
@@ -0,0 +1,139 @@
/* Copyright 2018 The Blueoil Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "global.h"
#include "matrix_view.h"
#include "operators.h" // FIXME(nikolay): for binary_convolution_parameters definition, rid of it later
#include "time_measurement.h"

#include <x86intrin.h>

namespace dlk {

namespace impl {

void ApplyThresholds(
dlk::MatrixView<BIN_CONV_OUTPUT, dlk::MatrixOrder::ColMajor> &result,
const binary_convolution_parameters &p) {
Measurement::Start("ApplyThresholds");

const auto buf_ts0 = std::make_unique<BIN_CONV_OUTPUT[]>(result.rows());
const auto buf_ts1 = std::make_unique<BIN_CONV_OUTPUT[]>(result.rows());
const auto buf_ts2 = std::make_unique<BIN_CONV_OUTPUT[]>(result.rows());
const auto buf_flg = std::make_unique<BIN_CONV_OUTPUT[]>(result.rows());

for (unsigned int i = 0; i < result.rows(); ++i) {
T_INT ts0 = p.thresholds[NUM_OF_A2W1_THRESHOLD * i];
T_INT ts1 = p.thresholds[NUM_OF_A2W1_THRESHOLD * i + 1];
T_INT ts2 = p.thresholds[NUM_OF_A2W1_THRESHOLD * i + 2];
T_INT flag = p.thresholds[NUM_OF_A2W1_THRESHOLD * i + 3];
if (flag == -1) {
++ts0;
++ts1;
++ts2;
}
buf_ts0[i] = ts0;
buf_ts1[i] = ts1;
buf_ts2[i] = ts2;
buf_flg[i] = flag;
}

#pragma omp parallel for
for (unsigned int j = 0; j < result.cols(); ++j) {
for (unsigned int i = 0; i < result.rows(); i += 16) {
const auto d = _mm256_loadu_si256(reinterpret_cast<__m256i*>(result.data(i, j)));
const auto ts0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(buf_ts0.get() + i));
const auto ts1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(buf_ts1.get() + i));
const auto ts2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(buf_ts2.get() + i));
const auto flg = _mm256_loadu_si256(reinterpret_cast<__m256i*>(buf_flg.get() + i));
const auto f0 = _mm256_andnot_si256(_mm256_cmpgt_epi16(ts0, d), flg);
const auto f1 = _mm256_andnot_si256(_mm256_cmpgt_epi16(ts1, d), flg);
const auto f2 = _mm256_andnot_si256(_mm256_cmpgt_epi16(ts2, d), flg);
const auto is_neg = _mm256_cmpgt_epi16(_mm256_setzero_si256(), flg);
const auto tmp = _mm256_add_epi16(_mm256_add_epi16(f0, f1), _mm256_add_epi16(f2, is_neg));
const auto m2 = _mm256_sub_epi16(flg, _mm256_set1_epi16(2));
const auto is_not_const = _mm256_cmpgt_epi16(_mm256_setzero_si256(), m2);
const auto res = _mm256_blendv_epi8(m2, tmp, is_not_const);
_mm256_storeu_si256(reinterpret_cast<__m256i*>(result.data(i, j)), res);
}
}

Measurement::Stop();
}

void ApplyThresholdsAndPack(
dlk::MatrixView<BIN_CONV_OUTPUT, dlk::MatrixOrder::ColMajor> &result,
const binary_convolution_parameters &p,
QUANTIZED_PACKED output[]) {
Measurement::Start("ApplyThresholdsAndPack");

const auto buf_ts0 = std::make_unique<BIN_CONV_OUTPUT[]>(result.rows());
const auto buf_ts1 = std::make_unique<BIN_CONV_OUTPUT[]>(result.rows());
const auto buf_ts2 = std::make_unique<BIN_CONV_OUTPUT[]>(result.rows());
const auto buf_flg = std::make_unique<BIN_CONV_OUTPUT[]>(result.rows());

for (unsigned int i = 0; i < result.rows(); ++i) {
T_INT ts0 = p.thresholds[NUM_OF_A2W1_THRESHOLD * i];
T_INT ts1 = p.thresholds[NUM_OF_A2W1_THRESHOLD * i + 1];
T_INT ts2 = p.thresholds[NUM_OF_A2W1_THRESHOLD * i + 2];
T_INT flag = p.thresholds[NUM_OF_A2W1_THRESHOLD * i + 3];
if (flag == -1) {
++ts0;
++ts1;
++ts2;
}
buf_ts0[i] = ts0;
buf_ts1[i] = ts1;
buf_ts2[i] = ts2;
buf_flg[i] = flag;
}

#pragma omp parallel for
for (unsigned int j = 0; j < result.cols(); ++j) {
for (unsigned int i = 0; i < result.rows(); i += 32) {
#define APPLY(k) \
const auto d##k = _mm256_loadu_si256(reinterpret_cast<__m256i*>(result.data(i + k * 16, j))); \
const auto ts0##k = _mm256_loadu_si256(reinterpret_cast<__m256i*>(buf_ts0.get() + i + k * 16)); \
const auto ts1##k = _mm256_loadu_si256(reinterpret_cast<__m256i*>(buf_ts1.get() + i + k * 16)); \
const auto ts2##k = _mm256_loadu_si256(reinterpret_cast<__m256i*>(buf_ts2.get() + i + k * 16)); \
const auto flg##k = _mm256_loadu_si256(reinterpret_cast<__m256i*>(buf_flg.get() + i + k * 16)); \
const auto f0##k = _mm256_andnot_si256(_mm256_cmpgt_epi16(ts0##k, d##k), flg##k); \
const auto f1##k = _mm256_andnot_si256(_mm256_cmpgt_epi16(ts1##k, d##k), flg##k); \
const auto f2##k = _mm256_andnot_si256(_mm256_cmpgt_epi16(ts2##k, d##k), flg##k); \
const auto is_neg##k = _mm256_cmpgt_epi16(_mm256_setzero_si256(), flg##k); \
const auto tmp##k = _mm256_add_epi16(_mm256_add_epi16(f0##k, f1##k), _mm256_add_epi16(f2##k, is_neg##k)); \
const auto m2##k = _mm256_sub_epi16(flg##k, _mm256_set1_epi16(2)); \
const auto is_not_const##k = _mm256_cmpgt_epi16(_mm256_setzero_si256(), m2##k); \
const auto res##k = _mm256_blendv_epi8(m2##k, tmp##k, is_not_const##k);
APPLY(0)
APPLY(1)
const auto packed = _mm256_packs_epi16(res0, res1);
const auto permuted = _mm256_permute4x64_epi64(packed, 0xD8);
const auto vlsb = _mm256_slli_epi32(permuted, 7);
const auto vmsb = _mm256_slli_epi32(permuted, 6);
const auto lsb = _mm256_movemask_epi8(vlsb);
const auto msb = _mm256_movemask_epi8(vmsb);
const auto index = (j + (i / 32) * result.cols()) * 2;
output[index] = QUANTIZED_PACKED(lsb);
output[index+1] = QUANTIZED_PACKED(msb);
}
}

Measurement::Stop();
}

} // namespace impl

} // namespace dlk
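For readers unfamiliar with the intrinsics, here is a scalar sketch (not part of the commit) of what each 16-bit lane of ApplyThresholds above computes. It assumes BIN_CONV_OUTPUT is a signed 16-bit type and takes the thresholds after the flag == -1 pre-increment done in the setup loop:

#include <cstdint>

// One lane of ApplyThresholds, scalar reference.
// flg is +1 (increasing channel), -1 (decreasing), or >= 2 (constant output).
inline int16_t apply_threshold_scalar(int16_t d, int16_t ts0, int16_t ts1,
                                      int16_t ts2, int16_t flg) {
  if (flg >= 2) return static_cast<int16_t>(flg - 2);  // constant-output channel
  const int16_t f0 = (d >= ts0) ? flg : 0;
  const int16_t f1 = (d >= ts1) ? flg : 0;
  const int16_t f2 = (d >= ts2) ? flg : 0;
  const int16_t is_neg = (flg < 0) ? -1 : 0;
  // increasing: 0..3; decreasing: -1..-4, whose low two bits encode 3 - count
  return static_cast<int16_t>(f0 + f1 + f2 + is_neg);
}

ApplyThresholdsAndPack fuses this same computation with the two-bit-plane packing that pack_16bit (next file) performs as a separate pass.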
51 changes: 51 additions & 0 deletions dlk/python/dlk/templates/src/func/impl/x86_avx/pack_16bit.cpp
@@ -0,0 +1,51 @@
/* Copyright 2019 The Blueoil Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "func/impl/pack_16bit.h"
#include <cassert>

#include <x86intrin.h>

#include "time_measurement.h"

namespace dlk {

namespace impl {

void pack_16bit(const BIN_CONV_OUTPUT input[], QUANTIZED_PACKED output[], const std::size_t length) {
using base = QUANTIZED_PACKED::base_t;
const auto bits = QUANTIZED_PACKED::BitCount;
assert((length % bits) == 0);
Measurement::Start("pack bits");
std::size_t j = 0;
for (std::size_t i = 0; i < length; i += bits) {
const auto v1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(input + i));
const auto v2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(input + i + 16));
const auto packed = _mm256_packs_epi16(v1, v2);
const auto permuted = _mm256_permute4x64_epi64(packed, 0xD8);
const auto vlsb = _mm256_slli_epi32(permuted, 7);
const auto vmsb = _mm256_slli_epi32(permuted, 6);
const auto lsb = _mm256_movemask_epi8(vlsb);
const auto msb = _mm256_movemask_epi8(vmsb);
output[j] = QUANTIZED_PACKED(lsb);
output[j+1] = QUANTIZED_PACKED(msb);
j += 2;
}
Measurement::Stop();
}

} // namespace impl

} // namespace dlk
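The packs/permute/shift/movemask sequence above splits each group of 32 two-bit values into two 32-bit bit planes. A scalar sketch of the same layout (not part of the commit; assumes QUANTIZED_PACKED wraps a 32-bit word, i.e. BitCount == 32, length is a multiple of 32 as the assert requires, and inputs fit in int8 so the saturating pack is lossless):

#include <cstddef>
#include <cstdint>

// Scalar reference: pack 32 elements per iteration into an lsb-plane word
// followed by an msb-plane word, matching the AVX output layout.
void pack_16bit_scalar(const int16_t input[], uint32_t output[],
                       std::size_t length) {
  for (std::size_t i = 0; i < length; i += 32) {
    uint32_t lsb = 0, msb = 0;
    for (std::size_t k = 0; k < 32; ++k) {
      const uint32_t v = static_cast<uint16_t>(input[i + k]);
      lsb |= (v & 1u) << k;         // bit 0 of each element
      msb |= ((v >> 1) & 1u) << k;  // bit 1 of each element
    }
    output[(i / 32) * 2] = lsb;
    output[(i / 32) * 2 + 1] = msb;
  }
}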
(Diff for the eighth changed file, likely src/func/impl/x86_avx/quantized_conv2d_tiling.cpp referenced in the Makefile above, was not rendered on this page.)
