diff --git a/.gitignore b/.gitignore index f5a8e060871..920cad42b6d 100644 --- a/.gitignore +++ b/.gitignore @@ -144,6 +144,7 @@ GSYMS /tools/mmseg-1.3.0.tar.gz /tools/mmseg-1.3.0/ /kaldiwin_vs* +/tools/cudnn/ /tools/cub-1.8.0.zip /tools/cub-1.8.0/ /tools/cub diff --git a/egs/mini_librispeech/s5/run.sh b/egs/mini_librispeech/s5/run.sh index 681859edf8a..997557f7904 100755 --- a/egs/mini_librispeech/s5/run.sh +++ b/egs/mini_librispeech/s5/run.sh @@ -48,7 +48,7 @@ if [ $stage -le 2 ]; then # spread the mfccs over various machines, as this data-set is quite large. if [[ $(hostname -f) == *.clsp.jhu.edu ]]; then mfcc=$(basename mfccdir) # in case was absolute pathname (unlikely), get basename. - utils/create_split_dir.pl /export/b{07,14,16,17}/$USER/kaldi-data/egs/librispeech/s5/$mfcc/storage \ + utils/create_split_dir.pl /export/b{07,14,16,18}/$USER/kaldi-data/egs/librispeech/s5/$mfcc/storage \ $mfccdir/storage fi @@ -170,9 +170,6 @@ if [ $stage -le 7 ]; then utils/build_const_arpa_lm.sh \ data/local/lm/lm_tglarge.arpa.gz data/lang data/lang_test_tglarge - - steps/align_fmllr.sh --nj 5 --cmd "$train_cmd" \ - data/train_clean_5 data/lang exp/tri3b exp/tri3b_ali_train_clean_5 fi diff --git a/src/configure b/src/configure index c4a1445efbd..543bd6ebabf 100755 --- a/src/configure +++ b/src/configure @@ -67,6 +67,7 @@ Configuration options: --shared Build and link against shared libraries [default=no] --use-cuda Build with CUDA [default=yes] --cudatk-dir=DIR CUDA toolkit directory + --cudnn-dir=DIR CUDNN installation directory --cuda-arch=FLAGS Override the default CUDA_ARCH flags. See https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#nvcc-examples. --double-precision Build with BaseFloat set to double if yes [default=no], mostly useful for testing purposes. @@ -431,9 +432,6 @@ function configure_cuda { if [ -z "$CUDA_ARCH" ]; then case $CUDA_VERSION in - 5_5) CUDA_ARCH="-gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35" ;; - 6_*) CUDA_ARCH="-gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_50,code=sm_50" ;; - 7_*) CUDA_ARCH="-gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_53,code=sm_53" ;; 8_*) CUDA_ARCH="-gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62" ;; 9_*) CUDA_ARCH="-gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62 -gencode arch=compute_70,code=sm_70" ;; 10_*) CUDA_ARCH="-gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_72,code=sm_72 -gencode arch=compute_75,code=sm_75" ;; @@ -450,6 +448,8 @@ function configure_cuda { echo "CUDA_ARCH = $CUDA_ARCH" >> kaldi.mk echo >> kaldi.mk + configure_cudnn + # 64bit/32bit? 
We do not support cross compilation with CUDA, so use direct calls to uname -m here
if [ "`uname -m`" == "x86_64" ]; then
  if [ "`uname`" == "Darwin" ]; then
@@ -462,7 +462,7 @@ function configure_cuda {
  elif [ "`uname -m`" == "ppc64le" ]; then
    cat makefiles/cuda_64bit.mk >> kaldi.mk
  else
-    cat makefiles/cuda_32bit.mk >> kaldi.mk
+    echo "Unexpected architecture `uname -m`"; exit 1
  fi
else
@@ -472,6 +472,47 @@ function configure_cuda {
  fi
}

+function configure_cudnn {
+  if [ -z "$CUDNNDIR" ]; then
+    download_appropriate_cudnn
+  fi
+
+  echo CUDNNDIR = $CUDNNDIR >> kaldi.mk
+  echo >> kaldi.mk
+
+  if [ ! -f $CUDNNDIR/lib64/libcudnn.so ] ||
+     [ ! -f $CUDNNDIR/include/cudnn.h ]; then
+    echo "CUDNNDIR(=$CUDNNDIR) invalid!"
+    exit 1
+  fi
+}
+
+function download_appropriate_cudnn {
+  local tools=`rel2abs ../tools`
+  install_dir=$tools/cudnn
+  CUDNNDIR=$tools/cudnn/cuda
+
+  if [ -f $CUDNNDIR/include/cudnn.h ]; then
+    echo -n "CUDNN has been downloaded already. If you'd like to redownload it "
+    echo -n "(e.g., because you changed CUDA version), please delete $CUDNNDIR "
+    echo "and rerun configure"
+    return
+  fi
+
+  local cudnn_url
+  case $CUDA_VERSION in
+    8_0) cudnn_url="http://developer.download.nvidia.com/compute/redist/cudnn/v7.1.2/cudnn-8.0-linux-x64-v7.1.tgz" ;;
+    9_0) cudnn_url="http://developer.download.nvidia.com/compute/redist/cudnn/v7.3.1/cudnn-9.0-linux-x64-v7.3.1.20.tgz" ;;
+    9_1) cudnn_url="http://developer.download.nvidia.com/compute/redist/cudnn/v7.1.2/cudnn-9.1-linux-x64-v7.1.tgz" ;;
+    9_2) cudnn_url="http://developer.download.nvidia.com/compute/redist/cudnn/v7.2.1/cudnn-9.2-linux-x64-v7.2.1.38.tgz" ;;
+    10_0) cudnn_url="http://developer.download.nvidia.com/compute/redist/cudnn/v7.3.1/cudnn-10.0-linux-x64-v7.3.1.20.tgz" ;;
+    *) echo "No known CUDNN download for provided CUDA_VERSION. Try checking here to see if your CUDA version supports a reasonably new version of CUDNN: https://gitlab.com/nvidia/cuda/tree/centos7"; exit 1 ;;
+  esac
+
+  mkdir -p $install_dir
+  wget -T 10 -t 3 $cudnn_url -O $install_dir/cudnn.tgz
+  tar --no-same-owner -xzf $install_dir/cudnn.tgz -C $install_dir
+}
+
 function linux_configure_speex {
   # Check whether the user has called tools/extras/install_speex.sh or not
   [ !
-z "$SPEEXROOT" ] || SPEEXROOT=`pwd`/../tools/speex @@ -989,6 +1030,9 @@ do --cudatk-dir=*) CUDATKDIR=`read_dirname $1`; shift ;; #CUDA is used in src/cudamatrix and src/nnet{,bin} only + --cudnn-dir=*) + CUDNNDIR=`read_dirname $1`; + shift ;; --cuda-arch=*) CUDA_ARCH=`read_value $1`; shift;; diff --git a/src/cudamatrix/cu-common.h b/src/cudamatrix/cu-common.h index 7446a76bf93..42a0a0347d2 100644 --- a/src/cudamatrix/cu-common.h +++ b/src/cudamatrix/cu-common.h @@ -74,6 +74,15 @@ } \ } +#define CUDNN_SAFE_CALL(fun) \ +do { \ + cudnnStatus_t ret; \ + if ((ret = (fun)) != CUDNN_STATUS_SUCCESS) { \ + KALDI_ERR << "cudnnStatus_t " << ret << " : \"" << cudnnGetErrorString(ret) \ + << "\" returned from '" << #fun << "'"; \ + } \ +} while(0) + namespace kaldi { diff --git a/src/cudamatrix/cu-device.cc b/src/cudamatrix/cu-device.cc index 49c179b3673..d31cdeb82a5 100644 --- a/src/cudamatrix/cu-device.cc +++ b/src/cudamatrix/cu-device.cc @@ -4,6 +4,7 @@ // 2013 Lucas Ondel // 2013-2015 Johns Hopkins University (author: Daniel Povey) // 2015 Guoguo Chen +// 2018 Daniel Galvez // See ../../COPYING for clarification regarding multiple authors // @@ -92,7 +93,8 @@ void CuDevice::Initialize() { // // (2) in threads created by the user, as soon as someone calls something that // might potentially use the GPU, via CuDevice()::Instantiate(). - // If device_id_ is >= 0, this will create the cuBLAS and cuSparse handles. + // If device_id_ is >= 0, this will create the cuBLAS, cuSparse, cuDNN + // handles. KALDI_ASSERT(!initialized_); initialized_ = true; if (device_id_ == -1) { @@ -113,6 +115,8 @@ void CuDevice::Initialize() { // Initialize the cuSPARSE library CUSPARSE_SAFE_CALL(cusparseCreate(&cusparse_handle_)); CUSPARSE_SAFE_CALL(cusparseSetStream(cusparse_handle_, cudaStreamPerThread)); + CUDNN_SAFE_CALL(cudnnCreate(&cudnn_handle_)); + CUDNN_SAFE_CALL(cudnnSetStream(cudnn_handle_, cudaStreamPerThread)); } } @@ -248,8 +252,10 @@ void CuDevice::FinalizeActiveGpu() { // Initialize the cuSPARSE library CUSPARSE_SAFE_CALL(cusparseCreate(&cusparse_handle_)); CUSPARSE_SAFE_CALL(cusparseSetStream(cusparse_handle_, cudaStreamPerThread)); + CUDNN_SAFE_CALL(cudnnCreate(&cudnn_handle_)); + CUDNN_SAFE_CALL(cudnnSetStream(cudnn_handle_, cudaStreamPerThread)); - // Notify the user which GPU is being userd. + // Notify the user which GPU is being used. 
char name[128]; DeviceGetName(name,128, device_id); @@ -511,7 +517,8 @@ CuDevice::CuDevice(): initialized_(false), device_id_copy_(-1), cublas_handle_(NULL), - cusparse_handle_(NULL) { + cusparse_handle_(NULL), + cudnn_handle_(NULL) { } CuDevice::~CuDevice() { @@ -519,6 +526,8 @@ CuDevice::~CuDevice() { CUBLAS_SAFE_CALL(cublasDestroy(cublas_handle_)); if (cusparse_handle_) CUSPARSE_SAFE_CALL(cusparseDestroy(cusparse_handle_)); + if (cudnn_handle_) + CUDNN_SAFE_CALL(cudnnDestroy(cudnn_handle_)); } diff --git a/src/cudamatrix/cu-device.h b/src/cudamatrix/cu-device.h index dc3df7e347d..95c447f3a7b 100644 --- a/src/cudamatrix/cu-device.h +++ b/src/cudamatrix/cu-device.h @@ -2,6 +2,7 @@ // Copyright 2009-2012 Karel Vesely // 2012-2015 Johns Hopkins University (author: Daniel Povey) +// 2018 Daniel Galvez // See ../../COPYING for clarification regarding multiple authors // @@ -30,6 +31,7 @@ #include #include #include +#include #include #include "base/kaldi-common.h" #include "base/timer.h" @@ -80,6 +82,7 @@ class CuDevice { inline cublasHandle_t GetCublasHandle() { return cublas_handle_; } inline cusparseHandle_t GetCusparseHandle() { return cusparse_handle_; } + inline cudnnHandle_t GetCudnnHandle() { return cudnn_handle_; } // We provide functions Malloc(), MallocPitch() and Free() which replace // cudaMalloc(), cudaMallocPitch() and cudaFree(). Their function is to cache @@ -271,6 +274,8 @@ class CuDevice { cusparseHandle_t cusparse_handle_; + cudnnHandle_t cudnn_handle_; + }; // class CuDevice @@ -289,6 +294,8 @@ inline cublasHandle_t GetCublasHandle() { return CuDevice::Instantiate().GetCubl // A more convenient way to get the handle to use cuSPARSE APIs. inline cusparseHandle_t GetCusparseHandle() { return CuDevice::Instantiate().GetCusparseHandle(); } +inline cudnnHandle_t GetCudnnHandle() { return CuDevice::Instantiate().GetCudnnHandle(); } + } // namespace kaldi diff --git a/src/ivector/ivector-extractor.h b/src/ivector/ivector-extractor.h index 3b9b6f3eb5c..938034859e2 100644 --- a/src/ivector/ivector-extractor.h +++ b/src/ivector/ivector-extractor.h @@ -468,7 +468,7 @@ struct IvectorExtractorEstimationOptions { "update any associated parameters."); opts->Register("diagonalize", &diagonalize, "If true, diagonalize the quadratic term in the " - "objective function. This reorders the ivector dimensions" + "objective function. This reorders the ivector dimensions " "from most to least important."); } }; diff --git a/src/makefiles/cuda_32bit.mk b/src/makefiles/cuda_32bit.mk deleted file mode 100644 index 9dcc965d7b6..00000000000 --- a/src/makefiles/cuda_32bit.mk +++ /dev/null @@ -1,15 +0,0 @@ -ifndef DOUBLE_PRECISION -$(error DOUBLE_PRECISION not defined.) -endif -ifndef CUDATKDIR -$(error CUDATKDIR not defined.) -endif - -CUDA_INCLUDE= -I$(CUDATKDIR)/include -I$(CUBROOT) -CUDA_FLAGS = -Xcompiler "-fPIC -pthread -isystem $(OPENFSTINC)" --verbose --machine 32 -DHAVE_CUDA \ - -ccbin $(CXX) -DKALDI_DOUBLEPRECISION=$(DOUBLE_PRECISION) \ - -std=c++11 -DCUDA_API_PER_THREAD_DEFAULT_STREAM - -CXXFLAGS += -DHAVE_CUDA -I$(CUDATKDIR)/include -CUDA_LDFLAGS += -L$(CUDATKDIR)/lib -Wl,-rpath,$(CUDATKDIR)/lib -CUDA_LDLIBS += -lcublas -lcusparse -lcudart -lcurand -lnvToolsExt #LDLIBS : The libs are loaded later than static libs in implicit rule diff --git a/src/makefiles/cuda_64bit.mk b/src/makefiles/cuda_64bit.mk index 72314d45622..232d3456022 100644 --- a/src/makefiles/cuda_64bit.mk +++ b/src/makefiles/cuda_64bit.mk @@ -4,12 +4,25 @@ endif ifndef CUDATKDIR $(error CUDATKDIR not defined.) 
endif +ifndef CUDNNDIR +$(error CUDNNDIR not defined.) +endif + +# Order matters here. We must tell the compiler to search +# $(CUDNNDIR)/lib64 before $(CUDATKDIR)/lib64 because the CUDNN .deb +# files install cudnn to /usr/local/cuda/lib64, which would overshadow +# the user-specified $(CUDNNDIR) +CUDA_INCLUDE += -I$(CUDNNDIR)/include +CXXFLAGS += -I$(CUDNNDIR)/include +CUDA_LDFLAGS += -L$(CUDNNDIR)/lib64 -Wl,-rpath,$(CUDNNDIR)/lib64 +CUDA_LDLIBS += -lcudnn CUDA_INCLUDE= -I$(CUDATKDIR)/include -I$(CUBROOT) -CUDA_FLAGS = -Xcompiler "-fPIC -pthread -isystem $(OPENFSTINC)" --verbose --machine 64 -DHAVE_CUDA \ +CUDA_FLAGS = -Xcompiler "-fPIC -pthread -isystem $(OPENFSTINC)" --verbose --machine 64 -DHAVE_CUDA=1 \ -ccbin $(CXX) -DKALDI_DOUBLEPRECISION=$(DOUBLE_PRECISION) \ - -std=c++11 -DCUDA_API_PER_THREAD_DEFAULT_STREAM + -std=c++11 -DCUDA_API_PER_THREAD_DEFAULT_STREAM -I$(CUDATKDIR)/include CXXFLAGS += -DHAVE_CUDA -I$(CUDATKDIR)/include + CUDA_LDFLAGS += -L$(CUDATKDIR)/lib64 -Wl,-rpath,$(CUDATKDIR)/lib64 CUDA_LDLIBS += -lcublas -lcusparse -lcudart -lcurand -lnvToolsExt #LDLIBS : The libs are loaded later than static libs in implicit rule diff --git a/src/matrix/kaldi-matrix.cc b/src/matrix/kaldi-matrix.cc index fcfe0616b64..df3aea3fe2c 100644 --- a/src/matrix/kaldi-matrix.cc +++ b/src/matrix/kaldi-matrix.cc @@ -28,7 +28,7 @@ #include "matrix/compressed-matrix.h" #include "matrix/sparse-matrix.h" -static_assert(int(kaldi::kNoTrans) == int(CblasNoTrans) && int(kaldi::kTrans) == int(CblasTrans), +static_assert(int(kaldi::kNoTrans) == int(CblasNoTrans) && int(kaldi::kTrans) == int(CblasTrans), "kaldi::kNoTrans and kaldi::kTrans must be equal to the appropriate CBLAS library constants!"); namespace kaldi { @@ -538,7 +538,7 @@ void MatrixBase::AddMatSmat(Real alpha, const MatrixBase &A, // pass stride to write a column as matrices are stored in row major order. cblas_Xaxpy(this_num_rows, alpha_B_jk, a_col_k, A.stride_, this_col_j, this->stride_); - //for (MatrixIndexT i = 0; i < this_num_rows; ++i) + //for (MatrixIndexT i = 0; i < this_num_rows; ++i) // this_col_j[i*this->stride_] += alpha_B_jk * a_col_k[i*A.stride_]; } } @@ -1656,11 +1656,12 @@ SubMatrix::SubMatrix(const MatrixBase &M, template -SubMatrix::SubMatrix(Real *data, +SubMatrix::SubMatrix(const Real *data, MatrixIndexT num_rows, MatrixIndexT num_cols, MatrixIndexT stride): - MatrixBase(data, num_cols, num_rows, stride) { // caution: reversed order! + MatrixBase(const_cast(data), + num_cols, num_rows, stride) { // caution: reversed order! if (data == NULL) { KALDI_ASSERT(num_rows * num_cols == 0); this->num_rows_ = 0; @@ -1839,7 +1840,8 @@ void MatrixBase::Svd(VectorBase *s, MatrixBase *U, MatrixBase< KALDI_ERR << "Error doing Svd (did not converge), first part of matrix is\n" << SubMatrix(*this, 0, std::min((MatrixIndexT)10, num_rows_), 0, std::min((MatrixIndexT)10, num_cols_)) - << ", min and max are: " << Min() << ", " << Max(); + << ", min, max and sum are: " << Min() << ", " << Max() + << ", " << Sum(); } } diff --git a/src/matrix/kaldi-matrix.h b/src/matrix/kaldi-matrix.h index a973824128c..a67f75260bf 100644 --- a/src/matrix/kaldi-matrix.h +++ b/src/matrix/kaldi-matrix.h @@ -952,9 +952,10 @@ class SubMatrix : public MatrixBase { const MatrixIndexT co, // column offset, 0 < co < NumCols() const MatrixIndexT c); // number of columns, c > 0 - // This initializer is mostly intended for use in CuMatrix and related - // classes. Be careful! 
- SubMatrix(Real *data, + // This initializer does not take ownership of the pointer, and to use it you + // need to have some understanding of how this library works. Caution: + // it can be used to get around const limitations, so be careful. + SubMatrix(const Real *data, MatrixIndexT num_rows, MatrixIndexT num_cols, MatrixIndexT stride); diff --git a/src/nnet3/Makefile b/src/nnet3/Makefile index aac16fb1c86..710c019cd56 100644 --- a/src/nnet3/Makefile +++ b/src/nnet3/Makefile @@ -12,7 +12,8 @@ TESTFILES = natural-gradient-online-test nnet-graph-test \ nnet-compile-utils-test nnet-nnet-test nnet-utils-test \ nnet-compile-test nnet-analyze-test nnet-compute-test \ nnet-optimize-test nnet-derivative-test nnet-example-test \ - nnet-common-test convolution-test attention-test + nnet-common-test convolution-test attention-test \ + convolution-cudnn-test OBJFILES = nnet-common.o nnet-compile.o nnet-component-itf.o \ nnet-simple-component.o nnet-normalize-component.o \ @@ -31,9 +32,14 @@ OBJFILES = nnet-common.o nnet-compile.o nnet-component-itf.o \ nnet-compile-looped.o decodable-simple-looped.o \ decodable-online-looped.o convolution.o \ nnet-convolutional-component.o attention.o \ - nnet-attention-component.o nnet-tdnn-component.o nnet-batch-compute.o + nnet-attention-component.o nnet-tdnn-component.o \ + nnet-batch-compute.o convolution-cudnn.o +ifeq ($(CUDA), true) +OBJFILES += convolution-cudnn.o +endif + LIBNAME = kaldi-nnet3 ADDLIBS = ../chain/kaldi-chain.a ../cudamatrix/kaldi-cudamatrix.a \ @@ -41,6 +47,6 @@ ADDLIBS = ../chain/kaldi-chain.a ../cudamatrix/kaldi-cudamatrix.a \ ../fstext/kaldi-fstext.a ../hmm/kaldi-hmm.a \ ../transform/kaldi-transform.a ../gmm/kaldi-gmm.a \ ../tree/kaldi-tree.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \ - ../base/kaldi-base.a + ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/nnet3/convolution-cudnn-test.cc b/src/nnet3/convolution-cudnn-test.cc new file mode 100644 index 00000000000..b51cb04ce82 --- /dev/null +++ b/src/nnet3/convolution-cudnn-test.cc @@ -0,0 +1,591 @@ +// nnet3/convolution-cudnn-test.cc + +// Copyright 2018 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "nnet3/convolution-cudnn.h" +#include "util/common-utils.h" + +namespace kaldi { +namespace nnet3 { +namespace cudnn_convolution { + +// for testing purposes, create a random ConvolutionComputation +static void GetRandomConfig(ConvolutionComputationConfig *config) { + config->num_images = RandInt(1, 10); + config->num_channels_out = RandInt(1, 10); + config->num_channels_in = RandInt(1, 10); + + config->filter_height = RandInt(1, 3); + config->filter_width = RandInt(1, 3); + + config->filter_stride_vertical = RandInt(1, 2); + config->filter_stride_horizontal = RandInt(1, 2); + config->filter_dilation_vertical = RandInt(1, 2); + config->filter_dilation_horizontal = RandInt(1, 2); + + config->input_image_height = RandInt(10, 20); + config->input_image_width = RandInt(10, 20); + + config->zero_padding_vertical = RandInt(0, 1); + config->zero_padding_horizontal = RandInt(0, 1); + + config->Check(); + config->ComputeOutputImageSize(); +} + +void TestConvolutionComputationConfig() { + for (int32 i = 0; i < 100; i++) { + ConvolutionComputationConfig config; + GetRandomConfig(&config); + std::ostringstream os; + bool binary = true; + config.Write(os, binary); + + ConvolutionComputationConfig config2; + std::istringstream is(os.str()); + config2.Read(is, binary); + std::ostringstream os2; + config2.Write(os2, binary); + KALDI_ASSERT(os.str() == os2.str()); + } +} + + +void ConvolveForwardWithGpu( + const ConvolutionComputation &computation, + const MatrixBase &input, + const MatrixBase ¶ms, + const VectorBase &bias, + MatrixBase *output) { + CuMatrix input_gpu(input.NumRows(), input.NumCols(), + kUndefined, kStrideEqualNumCols); + input_gpu.CopyFromMat(input); + CuMatrix params_gpu(params.NumRows(), params.NumCols(), + kUndefined, kStrideEqualNumCols); + params_gpu.CopyFromMat(params); + CuVector bias_gpu(bias); + CuMatrix output_gpu(output->NumRows(), output->NumCols(), + kSetZero, kStrideEqualNumCols); + computation.ConvolveForward(input_gpu, params_gpu, &bias_gpu, &output_gpu); + output->CopyFromMat(output_gpu); +} + + + +// Tests that the CPU version gives the expected results. The output +// has to be inspected by a human. +void TestConvolutionComputationForward( + const ConvolutionComputation &computation, + bool use_gpu) { + const ConvolutionComputationConfig &c = computation.Config(); + + Matrix input(c.num_images * c.input_image_width, + c.input_image_height * c.num_channels_in, + kSetZero, kStrideEqualNumCols), + output(c.num_images * c.output_image_width, + c.output_image_height * c.num_channels_out, + kSetZero, kStrideEqualNumCols), + params(c.num_channels_out, + c.num_channels_in * c.filter_width * c.filter_height, + kSetZero, kStrideEqualNumCols); + + // One parameter and one channel of input pixel will be nonzero-- for testing purposes. 
+ + int32 n = RandInt(0, c.num_images - 1), + input_w = RandInt(0, c.input_image_width - 1), + input_h = RandInt(0, c.input_image_height - 1), + input_c = RandInt(0, c.num_channels_in - 1); + input(n * c.input_image_width + input_w, + input_h * c.num_channels_in + input_c) = 2.0; + + int32 output_c = RandInt(0, c.num_channels_out - 1), + filter_w = RandInt(0, c.filter_width - 1), + filter_h = RandInt(0, c.filter_height - 1); + + params(output_c, + input_c * c.filter_width * c.filter_height + + filter_w * c.filter_height + + filter_h) = 3.0; + + Vector bias(c.num_channels_out); + + if (use_gpu) { + ConvolveForwardWithGpu(computation, input, params, bias, &output); + } else { + computation.ConvolveForward(input, params, &bias, + &output); + } + + KALDI_LOG << "Have nonzero input for n=" << n + << ", w=" << input_w << ", h=" << input_h + << ", input_channel=" << input_c; + KALDI_LOG << "Have nonzero filter for w=" + << filter_w << ", h=" << filter_h + << ", output_channel=" << output_c; + bool found_nonzero = false; + for (int32 n = 0; n < c.num_images; n++) { + for (int32 w = 0; w < c.output_image_width; w++) { + for (int32 h = 0; h < c.output_image_height; h++) { + for (int32 ch = 0; ch < c.num_channels_out; ch++) { + BaseFloat val = output(n * c.output_image_width + w, + h * c.num_channels_out + ch); + if (val != 0.0) { + found_nonzero = true; + KALDI_LOG << "Found nonzero value " << val << " for image n=" + << n << ", w=" << w << ", h=" << h + << ", output_channel=" << ch; + } + } + } + } + } + if (!found_nonzero) + KALDI_WARN << "Found no nonzero value, sum is " << output.Sum(); + + +} + +void TestConvolutionComputation() { + for (int32 i = 0; i < 100; i++) { + ConvolutionComputationConfig config; + GetRandomConfig(&config); + + { + std::ostringstream os; + config.Write(os, false); + KALDI_LOG << "Config is: " << os.str(); + } + + ConvolutionComputation computation(config); + + std::ostringstream os; + bool binary = true; + computation.Write(os, binary); + + ConvolutionComputation computation2; + std::istringstream is(os.str()); + computation2.Read(is, binary); + std::ostringstream os2; + computation2.Write(os2, binary); + KALDI_ASSERT(os.str() == os2.str()); + KALDI_LOG << "About to test without GPU."; + TestConvolutionComputationForward(computation2, false); +#if HAVE_CUDA == 1 + if (CuDevice::Instantiate().Enabled()) { + KALDI_LOG << "About to test with GPU"; + TestConvolutionComputationForward(computation2, true); + } +#endif + } +} + + +/** + + +// for testing purposes, create a set of input and output indexes for +// a convolution computation that are computable given this model. +static void GetRandomConvolutionIndexes(const ConvolutionModel &model, + std::vector *input_indexes, + std::vector *output_indexes) { + KALDI_ASSERT(model.Check()); + + std::vector > n_x_pairs; + int32 num_n_x_pairs = RandInt(1, 3); + for (int32 i = 0; i < num_n_x_pairs; i++) { + int32 n = RandInt(0, 3), x = RandInt(0, 1); + n_x_pairs.push_back(std::pair(n, x)); + } + SortAndUniq(&n_x_pairs); + num_n_x_pairs = n_x_pairs.size(); + + + // 'output_t_values' is the set of *possible* output + // t values; we'll later sub-sample from these. 
+ std::vector output_t_values; + + { + int32 out_t_start = RandInt(-5, 5), out_t_step = RandInt(1, 3), + num_t_out = RandInt(1, 4); + for (int32 i = 0; i < num_t_out; i++) + output_t_values.push_back(out_t_start + i * out_t_step); + } + + input_indexes->clear(); + output_indexes->clear(); + for (size_t i = 0; i < n_x_pairs.size(); i++) { + std::vector chosen_output_t_values; + while (chosen_output_t_values.empty()) { + for (size_t j = 0; j < output_t_values.size(); j++) + if (RandInt(0, 1) != 0) + chosen_output_t_values.push_back(output_t_values[j]); + } + KALDI_ASSERT(IsSortedAndUniq(chosen_output_t_values)); + + std::set required_input_t_values, + usable_input_t_values; + for (size_t j = 0; j < chosen_output_t_values.size(); j++) { + std::set::const_iterator iter; + int32 t_out = chosen_output_t_values[j]; + for (iter = model.required_time_offsets.begin(); + iter != model.required_time_offsets.end(); iter++) { + int32 offset = *iter; + required_input_t_values.insert(t_out + offset); + } + for (iter = model.all_time_offsets.begin(); + iter != model.all_time_offsets.end(); iter++) { + int32 offset = *iter; + usable_input_t_values.insert(t_out + offset); + } + } + + // add to output_indexes + for (size_t j = 0; j < chosen_output_t_values.size(); j++) { + int32 t_out = chosen_output_t_values[j]; + Index index; + index.n = n_x_pairs[i].first; + index.x = n_x_pairs[i].second; + index.t = t_out; + output_indexes->push_back(index); + } + + std::vector chosen_input_t_values(required_input_t_values.begin(), + required_input_t_values.end()); + for (std::set::const_iterator iter = usable_input_t_values.begin(); + iter != usable_input_t_values.end(); ++iter) { + int32 t = *iter; + if (RandInt(0, 1) == 0) + chosen_input_t_values.push_back(t); + } + SortAndUniq(&chosen_input_t_values); + + // add to input_indexes + for (size_t j = 0; j < chosen_input_t_values.size(); j++) { + int32 t_in = chosen_input_t_values[j]; + Index index; + index.n = n_x_pairs[i].first; + index.x = n_x_pairs[i].second; + index.t = t_in; + input_indexes->push_back(index); + } + } +} + + +void UnitTestTimeHeightConvolutionIo() { + for (int32 i = 0; i < 10; i++) { + KALDI_LOG << "iter = " << i; + // Create a ConvolutionModel and test its I/O. + ConvolutionModel conv_model; + GetRandomConvolutionModel(&conv_model); + std::ostringstream os1, os2; + bool binary = (RandInt(0, 1) == 0); + conv_model.Write(os1, binary); + std::istringstream is(os1.str()); + ConvolutionModel conv_model2; + conv_model2.Read(is, binary); + conv_model2.Write(os2, binary); + KALDI_ASSERT(os1.str() == os2.str() && conv_model2.Check()); + } +} + +void TestComputationIo(const ConvolutionComputation &computation) { + std::ostringstream os1, os2; + bool binary = (RandInt(0, 1) == 0); + computation.Write(os1, binary); + std::istringstream is(os1.str()); + ConvolutionComputation computation2; + computation2.Read(is, binary); + computation2.Write(os2, binary); + KALDI_ASSERT(os1.str() == os2.str()); + computation2.Check(); +} + + +// This function exects indexes.size() == matrix->NumRows(); +// it sets to zero any row i of the matrix for which +// indexes[i].t == kNoTime. 
+void ZeroBlankRows(const std::vector &indexes, + CuMatrix *matrix) { + KALDI_ASSERT(static_cast(indexes.size()) == matrix->NumRows()); + int32 num_rows = matrix->NumRows(); + if (num_rows == 0) return; + Vector mask(num_rows, kUndefined); + mask.Set(1.0); + const Index *indexes_ptr = &(indexes[0]); + BaseFloat *mask_ptr = mask.Data(); + for (int32 r = 0; r < num_rows; r++) { + if (indexes_ptr[r].t == kNoTime) + mask_ptr[r] = 0.0; + } + CuVector cu_mask; + cu_mask.Swap(&mask); + matrix->MulRowsVec(cu_mask); +} + +// This is a 'dumb' implementation of convolution, created to compare +// with ConvolveForward. +void ConvolveForwardSimple( + const ConvolutionModel &model, + const std::vector &input_indexes, + const std::vector &output_indexes, + const CuMatrixBase &input_cu, + const CuMatrixBase ¶ms_cu, + CuMatrixBase *output_cu) { + // these loops will be very slow on GPU, so do it all on CPU. + Matrix input(input_cu), params(params_cu), + output(*output_cu); + std::unordered_map index_to_row; + int32 input_rows = input.NumRows(), + output_rows = output.NumRows(); + for (int32 r_in = 0; r_in < input_rows; r_in++) { + if (input_indexes[r_in].t != kNoTime) { + index_to_row[input_indexes[r_in]] = r_in; + } + } + int32 num_offsets = model.offsets.size(), + num_filters_in = model.num_filters_in, + num_filters_out = model.num_filters_out, + height_in = model.height_in, + height_out = model.height_out, + height_subsample_out = model.height_subsample_out; + for (int32 r_out = 0; r_out < output_rows; r_out++) { + Index index_out = output_indexes[r_out]; + if (index_out.t == kNoTime) + continue; + SubVector output_row(output, r_out); + for (int32 o = 0; o < num_offsets; o++) { + int32 time_offset = model.offsets[o].time_offset, + height_offset = model.offsets[o].height_offset; + Index index_in(index_out); + index_in.t += time_offset; + std::unordered_map::const_iterator iter = + index_to_row.find(index_in); + if (iter != index_to_row.end()) { + SubMatrix params_part(params, 0, params.NumRows(), + o * num_filters_in, num_filters_in); + int32 r_in = iter->second; + SubVector input_row(input, r_in); + for (int32 h_out_subsampled = 0; + h_out_subsampled < height_out; + h_out_subsampled++) { + int32 h_out = h_out_subsampled * height_subsample_out, + h_in = h_out + height_offset; + if (h_in < 0 || h_in >= height_in) + continue; + SubVector output_part(output_row, + h_out_subsampled * num_filters_out, + num_filters_out), + input_part(input_row, h_in * num_filters_in, num_filters_in); + output_part.AddMatVec(1.0, params_part, kNoTrans, input_part, 1.0); + } + } + } + } + output_cu->CopyFromMat(output); +} + + + +void TestRunningComputation(const ConvolutionModel &conv_model, + const std::vector &input_indexes, + const std::vector &output_indexes, + const ConvolutionComputation &computation) { + CuMatrix input(input_indexes.size(), conv_model.InputDim(), + kSetZero, kStrideEqualNumCols), + output(output_indexes.size(), conv_model.OutputDim(), + kSetZero, kStrideEqualNumCols), + output2(output), + params(conv_model.ParamRows(), conv_model.ParamCols()); + input.SetRandn(); + params.SetRandn(); + ZeroBlankRows(input_indexes, &input); + ConvolveForward(computation, input, params, &output); + ZeroBlankRows(output_indexes, &output); + + ConvolveForwardSimple(conv_model, input_indexes, output_indexes, + input, params, &output2); + KALDI_LOG << "Tested convolution for model: " + << conv_model.Info(); + if (!output.ApproxEqual(output2, 0.001)) { + KALDI_LOG << "Output is: " << output; + KALDI_LOG << "Output2 is: " << 
output2; + KALDI_ERR << "Convolution test failure."; + } +} + + +void TestDataBackprop(const ConvolutionModel &conv_model, + const std::vector &input_indexes, + const std::vector &output_indexes, + const ConvolutionComputation &computation) { + CuMatrix + input_deriv(input_indexes.size(), conv_model.InputDim(), + kSetZero, kStrideEqualNumCols), + input(input_indexes.size(), conv_model.InputDim(), + kSetZero, kStrideEqualNumCols), + output(output_indexes.size(), conv_model.OutputDim(), + kSetZero, kStrideEqualNumCols), + output_deriv(output_indexes.size(), conv_model.OutputDim(), + kSetZero, kStrideEqualNumCols), + params(conv_model.ParamRows(), conv_model.ParamCols()); + + input.SetRandn(); + params.SetRandn(); + output_deriv.SetRandn(); + + ZeroBlankRows(output_indexes, &output_deriv); + ConvolveBackwardData(computation, params, output_deriv, &input_deriv); + ZeroBlankRows(input_indexes, &input_deriv); + ZeroBlankRows(input_indexes, &input); + + // define the objf as TraceMatMat(output_deriv, output, kTrans). + // we can work it out from the backpropagated data-derivative. + BaseFloat expected_objf = TraceMatMat(input_deriv, input, kTrans); + + ConvolveForward(computation, input, params, &output); + ZeroBlankRows(output_indexes, &output); + + BaseFloat observed_objf = TraceMatMat(output, output_deriv, kTrans); + + KALDI_LOG << "Expected objf = " << expected_objf + << ", observed objf = " << observed_objf; + if (!ApproxEqual(expected_objf, observed_objf, 0.1) && + fabs(expected_objf) < 1.0) { + KALDI_ERR << "Difference in objf too large."; + } +} + + +void TestParamsBackprop(const ConvolutionModel &conv_model, + const std::vector &input_indexes, + const std::vector &output_indexes, + const ConvolutionComputation &computation) { + CuMatrix + input(input_indexes.size(), conv_model.InputDim(), + kSetZero, kStrideEqualNumCols), + output(output_indexes.size(), conv_model.OutputDim(), + kSetZero, kStrideEqualNumCols), + output_deriv(output_indexes.size(), conv_model.OutputDim(), + kSetZero, kStrideEqualNumCols), + params(conv_model.ParamRows(), conv_model.ParamCols()), + params_deriv(conv_model.ParamRows(), conv_model.ParamCols()); + + input.SetRandn(); + params.SetRandn(); + output_deriv.SetRandn(); + + BaseFloat alpha = 0.5 * RandInt(1, 3); + + ZeroBlankRows(output_indexes, &output_deriv); + ZeroBlankRows(input_indexes, &input); + + ConvolveBackwardParams(computation, input, output_deriv, alpha, + ¶ms_deriv); + + BaseFloat expected_objf = TraceMatMat(params_deriv, params, kTrans) / alpha; + + ConvolveForward(computation, input, params, &output); + + ZeroBlankRows(output_indexes, &output); + + BaseFloat observed_objf = TraceMatMat(output, output_deriv, kTrans); + + KALDI_LOG << "Expected objf = " << expected_objf + << ", observed objf = " << observed_objf; + if (!ApproxEqual(expected_objf, observed_objf, 0.1) && + fabs(expected_objf) < 1.0) { + KALDI_ERR << "Difference in objf too large."; + } +} + + + +void UnitTestTimeHeightConvolutionCompile() { + for (int32 i = 0; i < 10; i++) { + KALDI_LOG << "iter = " << i; + // Create a ConvolutionModel + ConvolutionModel conv_model; + GetRandomConvolutionModel(&conv_model); + std::vector input_indexes, output_indexes; + GetRandomConvolutionIndexes(conv_model, &input_indexes, &output_indexes); + + ConvolutionComputationOptions opts; + ConvolutionComputation computation; + std::vector input_indexes_modified, output_indexes_modified; + CompileConvolutionComputation(conv_model, input_indexes, output_indexes, + opts, &computation, + &input_indexes_modified, 
+ &output_indexes_modified); + TestComputationIo(computation); + TestRunningComputation(conv_model, + input_indexes_modified, + output_indexes_modified, + computation); + TestDataBackprop(conv_model, + input_indexes_modified, + output_indexes_modified, + computation); + TestParamsBackprop(conv_model, + input_indexes_modified, + output_indexes_modified, + computation); + std::ostringstream os; + os << "\nInput-indexes: "; + WriteIndexVector(os, false, input_indexes); + os << "\nInput-indexes-modified: "; + WriteIndexVector(os, false, input_indexes_modified); + os << "\nOutput-indexes: "; + WriteIndexVector(os, false, output_indexes); + os << "\nOutput-indexes-modified: "; + WriteIndexVector(os, false, output_indexes_modified); + KALDI_LOG << os.str(); + } +} + + +void UnitTestTimeHeightConvolution() { + UnitTestTimeHeightConvolutionIo(); + UnitTestTimeHeightConvolutionCompile(); +} + +*/ + +} // namespace cudnn_convolution +} // namespace nnet3 +} // namespace kaldi + + +int main() { + using namespace kaldi; + using namespace kaldi::nnet3; + using namespace kaldi::nnet3::cudnn_convolution; + + + for (int32 loop = 0; loop < 2; loop++) { +#if HAVE_CUDA == 1 + CuDevice::Instantiate().SetDebugStrideMode(true); + if (loop == 0) + CuDevice::Instantiate().SelectGpuId("no"); // -1 means no GPU + else + CuDevice::Instantiate().SelectGpuId("optional"); // -2 .. automatic selection +#endif + TestConvolutionComputationConfig(); + TestConvolutionComputation(); + } +} diff --git a/src/nnet3/convolution-cudnn.cc b/src/nnet3/convolution-cudnn.cc new file mode 100644 index 00000000000..d291387c350 --- /dev/null +++ b/src/nnet3/convolution-cudnn.cc @@ -0,0 +1,657 @@ +// nnet3/convolution-cudnn.cc + +// Copyright 2018 Daniel Galvez +// 2018 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "nnet3/convolution-cudnn.h" + +namespace kaldi { +namespace nnet3 { +namespace cudnn_convolution { + + +namespace { +// Note: anonymous namespaces are now preferred (by the C++ standard) over +// static variables. 
+const BaseFloat ONE(1.0); +const BaseFloat ZERO(0.0); + +template +void CheckCorrectness(CudnnAlgoPerfT perf_results, const char* function) { + if (perf_results.status != CUDNN_STATUS_SUCCESS) + KALDI_ERR << function << " had an error: " << + cudnnGetErrorString(perf_results.status); +} +} + + +void ConvolutionComputationConfig::Check() { + KALDI_ASSERT(num_images > 0 && num_channels_out > 0 && + num_channels_in > 0 && filter_height > 0 && filter_width > 0); + KALDI_ASSERT(filter_stride_vertical > 0 && filter_stride_horizontal > 0 && + filter_dilation_vertical > 0 && filter_dilation_horizontal > 0); + KALDI_ASSERT(input_image_height > 0 && input_image_width > 0 && + zero_padding_vertical >= 0 && zero_padding_horizontal >= 0); +} + +void ConvolutionComputationConfig::ComputeOutputImageSize() { + { // This blocks deals with the vertical direction. + + // 'filter_height_reduction' is the amount by which the height of the filter patch + // reduces the effective height of the input image. It's the distance between + // the first and last pixels of the filter patch. E.g. in a 3x3 kernel it + // would be 2. + int32 filter_height_reduction = (filter_height - 1) * filter_dilation_vertical; + // 'modified_input_height' is the number of times we can shift the filter patch + // (not yet taking account of any filter stride). It's a kind of augmented input-image + // height, after applying zero-padding and subtracting filter_height_reduction. + int32 modified_input_height = + input_image_height - filter_height_reduction + (zero_padding_vertical * 2), + s = filter_stride_vertical; + + // output_image_height equals reduced_input_height divided by s (but rounding + // up), which is the number of times we can shift the filter patch by + // filter_stride_vertical_. + output_image_height = (modified_input_height + s - 1) / s; + } + + { // This blocks deals with the horizontal direction. + + // 'filter_width_reduction' is the amount by which the width of the filter patch + // reduces the effective width of the input image. It's the distance between + // the first and last pixels of the filter patch. E.g. in a 3x3 kernel it + // would be 2. + int32 filter_width_reduction = (filter_width - 1) * filter_dilation_horizontal; + // 'modified_input_width' is the number of times we can shift the filter patch + // (not yet taking account of any filter stride). It's a kind of augmented input-image + // width, after applying zero-padding and subtracting filter_width_reduction. + int32 modified_input_width = + input_image_width - filter_width_reduction + (zero_padding_horizontal * 2), + s = filter_stride_horizontal; + + // output_image_width equals reduced_input_width divided by s (but rounding + // up), which is the number of times we can shift the filter patch by + // filter_stride_horizontal_. 
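  // As a worked example (numbers chosen here for illustration, not taken from
  // the patch): with filter_width = 3 and filter_dilation_horizontal = 1 the
  // filter spans 3 pixels, so filter_width_reduction = 2; for
  // input_image_width = 100, zero_padding_horizontal = 0 and
  // filter_stride_horizontal = 3 we get modified_input_width = 100 - 2 + 0 = 98
  // and output_image_width = (98 + 3 - 1) / 3 = 33.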
+ output_image_width = (modified_input_width + s - 1) / s; + } +} + +void ConvolutionComputationConfig::Write(std::ostream &os, bool binary) const { + WriteToken(os, binary, ""); + WriteBasicType(os, binary, num_images); + WriteToken(os, binary, ""); + WriteBasicType(os, binary, num_channels_in); + WriteBasicType(os, binary, num_channels_out); + WriteToken(os, binary, ""); + WriteBasicType(os, binary, filter_height); + WriteBasicType(os, binary, filter_width); + WriteBasicType(os, binary, filter_stride_vertical); + WriteBasicType(os, binary, filter_stride_horizontal); + WriteBasicType(os, binary, filter_dilation_vertical); + WriteBasicType(os, binary, filter_dilation_horizontal); + WriteToken(os, binary, ""); + WriteBasicType(os, binary, input_image_height); + WriteBasicType(os, binary, input_image_width); + WriteToken(os, binary, ""); + WriteBasicType(os, binary, zero_padding_vertical); + WriteBasicType(os, binary, zero_padding_horizontal); + WriteToken(os, binary, ""); +} + +void ConvolutionComputationConfig::Read(std::istream &is, bool binary) { + ExpectToken(is, binary, ""); + ReadBasicType(is, binary, &num_images); + ExpectToken(is, binary, ""); + ReadBasicType(is, binary, &num_channels_in); + ReadBasicType(is, binary, &num_channels_out); + ExpectToken(is, binary, ""); + ReadBasicType(is, binary, &filter_height); + ReadBasicType(is, binary, &filter_width); + ReadBasicType(is, binary, &filter_stride_vertical); + ReadBasicType(is, binary, &filter_stride_horizontal); + ReadBasicType(is, binary, &filter_dilation_vertical); + ReadBasicType(is, binary, &filter_dilation_horizontal); + ExpectToken(is, binary, ""); + ReadBasicType(is, binary, &input_image_height); + ReadBasicType(is, binary, &input_image_width); + ExpectToken(is, binary, ""); + ReadBasicType(is, binary, &zero_padding_vertical); + ReadBasicType(is, binary, &zero_padding_horizontal); + ExpectToken(is, binary, ""); + ComputeOutputImageSize(); +} + + +ConvolutionComputation::ConvolutionComputation( + const ConvolutionComputationConfig &config): config_(config) { + config_.Check(); + config_.ComputeOutputImageSize(); +#if HAVE_CUDA == 1 + if (CuDevice::Instantiate().Enabled()) { + InitCudnn(); + } +#endif +} + +ConvolutionComputation::ConvolutionComputation() { +#if HAVE_CUDA == 1 + descriptors_initialized_ = false; +#endif +} + + +#if HAVE_CUDA == 1 + +#if (KALDI_DOUBLEPRECISION != 0) +#define CUDNN_DATA_BASEFLOAT CUDNN_DATA_DOUBLE +#else +#define CUDNN_DATA_BASEFLOAT CUDNN_DATA_FLOAT +#endif + +void ConvolutionComputation::InitCudnn() { + descriptors_initialized_ = true; + + const ConvolutionComputationConfig &c = config_; + + CUDNN_SAFE_CALL(cudnnCreateTensorDescriptor(&input_desc_)); + CUDNN_SAFE_CALL(cudnnCreateTensorDescriptor(&output_desc_)); + CUDNN_SAFE_CALL(cudnnCreateFilterDescriptor(¶ms_desc_)); + CUDNN_SAFE_CALL(cudnnCreateTensorDescriptor(&bias_desc_)); + CUDNN_SAFE_CALL(cudnnCreateConvolutionDescriptor(&conv_desc_)); + + // Caution: in the following call, the 'height' and 'width' are swapped + // relative to what the CUDNN interface specifies; this is because Kaldi's + // notion of what is height vs. width is opposite to CUDNN's. (There + // are good reasons for this). + // We use cudnnSetTensorNdDescriptor because of bugs in cudnnSetTensor4dDescriptor. 
+ int in_dims[4] = {c.num_images, c.num_channels_in, c.input_image_width, + c.input_image_height}; + int in_stride[4] = {c.num_channels_in * c.input_image_width * c.input_image_height, + c.input_image_width * c.input_image_height, + c.input_image_height, 1}; + CUDNN_SAFE_CALL( + cudnnSetTensorNdDescriptor(input_desc_, CUDNN_DATA_BASEFLOAT, 4, in_dims, + in_stride)); + // Again: width and height are swapped. + CUDNN_SAFE_CALL( + cudnnSetConvolution2dDescriptor( + conv_desc_, + c.zero_padding_horizontal, c.zero_padding_vertical, + c.filter_stride_horizontal, c.filter_stride_vertical, + c.filter_dilation_horizontal, c.filter_dilation_vertical, + CUDNN_CROSS_CORRELATION, // TODO: Double check this! + CUDNN_DATA_BASEFLOAT)); + + // Set dimensions of the filters (linear parameters). + // Again: width and height are swapped. Per the CUDNN documentation at + // https://docs.nvidia.com/deeplearning/sdk/pdf/cuDNN-Developer-Guide.pdf for + // cudnnSetFilter4dDescriptor, setting CUDNN_TENSOR_NCHW as the layout + // corresponds to KCRS, meaning: num-channels-out, num-channels-in, height, width, + // where 'height' and 'width' are the filter height and width respectively (e.g. 3 + // and 3 for a 3x3 patch); and these are swapped w.r.t. Kaldi's notion of height and + // width, so as far as Kaldi is concerned, the strides are, from largest to + // smallest: num-channels-out, width, height, num-channels-in: so as far + // as Kaldi is concerned the layout is KCWH (== KCSR, in their notation). + CUDNN_SAFE_CALL( + cudnnSetFilter4dDescriptor(params_desc_, CUDNN_DATA_BASEFLOAT, + CUDNN_TENSOR_NCHW, c.num_channels_out, + c.num_channels_in, c.filter_width, + c.filter_height)); + + int32 kaldi_width_cudnn_height, kaldi_height_cudnn_width, unused; + CUDNN_SAFE_CALL( + cudnnGetConvolution2dForwardOutputDim(conv_desc_, input_desc_, + params_desc_, + &unused, &unused, + &kaldi_width_cudnn_height, + &kaldi_height_cudnn_width)); + + if (kaldi_height_cudnn_width != c.output_image_height) + KALDI_ERR << "Code error: the height from CUDNN " << kaldi_height_cudnn_width + << " does not match our value " << c.output_image_height; + if (kaldi_width_cudnn_height != c.output_image_width) + KALDI_ERR << "Code error: the width from CUDNN " << kaldi_width_cudnn_height + << " does not match our value " << c.output_image_width; + + // These two member functions depend only on input_desc_, + // conv_desc_, and params_desc_, so they are safe to call now. + int out_dims[4] = {c.num_images, c.num_channels_out, c.output_image_width, + c.output_image_height}; + int out_stride[4] = {c.num_channels_out * c.output_image_width * c.output_image_height, + c.output_image_width * c.output_image_height, + c.output_image_height, 1}; + CUDNN_SAFE_CALL( + cudnnSetTensorNdDescriptor(output_desc_, CUDNN_DATA_BASEFLOAT, 4, out_dims, + out_stride)); + + // Since the output tensor shape is NKHW, we need the bias to be + // four-dimensional and the length of each dimension of the bias + // equal to either one or the output tensor's corresponding + // length. Singleton dimensions are broadcasted. 
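  // For example (illustrative value only): with num_channels_out = 64 the bias
  // descriptor below is 1 x 64 x 1 x 1, and cudnnAddTensor() broadcasts it
  // along every output dimension except the output-channel dimension.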
+ int bias_dims[4] = {1, c.num_channels_out, 1, 1}; + int bias_stride[4] = {c.num_channels_out, 1, 1, 1}; + CUDNN_SAFE_CALL( + cudnnSetTensorNdDescriptor(bias_desc_, CUDNN_DATA_BASEFLOAT, 4, + bias_dims, bias_stride)); + + int32 requested_algo_count, returned_algo_count; + CUDNN_SAFE_CALL(cudnnGetConvolutionForwardAlgorithmMaxCount( + GetCudnnHandle(), &requested_algo_count)); + + cudnnConvolutionFwdAlgoPerf_t *forward_results = + new cudnnConvolutionFwdAlgoPerf_t[requested_algo_count]; + CUDNN_SAFE_CALL(cudnnFindConvolutionForwardAlgorithm( + GetCudnnHandle(), + input_desc_, + params_desc_, + conv_desc_, + output_desc_, + requested_algo_count, + &returned_algo_count, + forward_results)); + + KALDI_ASSERT(returned_algo_count > 0 && + "No algorithms were returned by CUDNN."); + const cudnnConvolutionFwdAlgoPerf_t& best_forward = forward_results[0]; + CheckCorrectness(best_forward, "cudnnFindConvolutionForwardAlgorithm"); + fwd_algo_ = best_forward.algo; + delete [] forward_results; + + CUDNN_SAFE_CALL(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount( + GetCudnnHandle(), &requested_algo_count)); + cudnnConvolutionBwdFilterAlgoPerf_t *backward_filter_results = + new cudnnConvolutionBwdFilterAlgoPerf_t[requested_algo_count]; + CUDNN_SAFE_CALL(cudnnFindConvolutionBackwardFilterAlgorithm( + GetCudnnHandle(), + input_desc_, + output_desc_, + conv_desc_, + params_desc_, + requested_algo_count, + &returned_algo_count, + backward_filter_results)); + KALDI_ASSERT(returned_algo_count > 0 && + "No algorithms were returned by CUDNN."); + const cudnnConvolutionBwdFilterAlgoPerf_t& best_backward_filter = + backward_filter_results[0]; + CheckCorrectness(best_backward_filter, + "cudnnFindConvolutionBackwardFilterAlgorithm"); + bwd_filter_algo_ = best_backward_filter.algo; + delete [] backward_filter_results; + + CUDNN_SAFE_CALL(cudnnGetConvolutionBackwardDataAlgorithmMaxCount( + GetCudnnHandle(), &requested_algo_count)); + cudnnConvolutionBwdDataAlgoPerf_t *backward_data_results = + new cudnnConvolutionBwdDataAlgoPerf_t[requested_algo_count]; + CUDNN_SAFE_CALL(cudnnFindConvolutionBackwardDataAlgorithm( + GetCudnnHandle(), + params_desc_, + output_desc_, + conv_desc_, + input_desc_, + requested_algo_count, + &returned_algo_count, + backward_data_results)); + KALDI_ASSERT(returned_algo_count > 0 && + "No algorithms were returned by CUDNN."); + const cudnnConvolutionBwdDataAlgoPerf_t& best_backward_data = + backward_data_results[0]; + CheckCorrectness(best_backward_data, + "cudnnFindConvolutionBackwardDataAlgorithm"); + bwd_data_algo_ = best_backward_data.algo; + delete [] backward_data_results; + + ComputeTempSpaceSizes(); +} +#endif + +#if HAVE_CUDA == 1 +void ConvolutionComputation::ComputeTempSpaceSizes() { + CUDNN_SAFE_CALL(cudnnGetConvolutionForwardWorkspaceSize( + GetCudnnHandle(), + input_desc_, + params_desc_, + conv_desc_, + output_desc_, + fwd_algo_, + &temp_space_required_forward_)); + + CUDNN_SAFE_CALL(cudnnGetConvolutionBackwardDataWorkspaceSize( + GetCudnnHandle(), + params_desc_, + output_desc_, + conv_desc_, + input_desc_, + bwd_data_algo_, + &temp_space_required_backward_data_)); + + CUDNN_SAFE_CALL(cudnnGetConvolutionBackwardFilterWorkspaceSize( + GetCudnnHandle(), + input_desc_, + output_desc_, + conv_desc_, + params_desc_, + bwd_filter_algo_, + &temp_space_required_backward_filter_)); +} +#endif + +#if HAVE_CUDA == 1 +void ConvolutionComputation::DestroyCudnn() { + CUDNN_SAFE_CALL(cudnnDestroyTensorDescriptor(input_desc_)); + CUDNN_SAFE_CALL(cudnnDestroyTensorDescriptor(output_desc_)); + 
CUDNN_SAFE_CALL(cudnnDestroyFilterDescriptor(params_desc_)); + CUDNN_SAFE_CALL(cudnnDestroyTensorDescriptor(bias_desc_)); + CUDNN_SAFE_CALL(cudnnDestroyConvolutionDescriptor(conv_desc_)); +} +#endif + +ConvolutionComputation::~ConvolutionComputation() { +#if HAVE_CUDA == 1 + if (CuDevice::Instantiate().Enabled() && descriptors_initialized_) + DestroyCudnn(); +#endif +} + + +void ConvolutionComputation::Read(std::istream &is, bool binary) { + config_.Read(is, binary); +#if HAVE_CUDA == 1 + if (CuDevice::Instantiate().Enabled()) { + InitCudnn(); + } +#endif +} + + +void ConvolutionComputation:: +ConvolveForward(const CuMatrixBase &input, + const CuMatrixBase ¶ms, + const CuVectorBase *bias, + CuMatrixBase *output) const { + const ConvolutionComputationConfig &c = config_; + // Check some dimensions. + KALDI_ASSERT( + input.NumRows() == c.num_images * c.input_image_width && + input.NumCols() == c.input_image_height * c.num_channels_in && + input.Stride() == input.NumCols() && + params.NumRows() == c.num_channels_out && + params.NumCols() == c.num_channels_in * c.filter_height * c.filter_width && + params.Stride() == params.NumCols()); + KALDI_ASSERT( + (bias == nullptr || bias->Dim() == c.num_channels_out) && + output->NumRows() == c.num_images * c.output_image_width && + output->NumCols() == c.output_image_height * c.num_channels_out && + output->Stride() == output->NumCols()); + +#if HAVE_CUDA == 1 + if (CuDevice::Instantiate().Enabled()) { + CuVector temp_space(temp_space_required_forward_ / + sizeof(BaseFloat), kUndefined); + + CUDNN_SAFE_CALL(cudnnConvolutionForward( + GetCudnnHandle(), + &ONE, + input_desc_, + input.Data(), + params_desc_, + params.Data(), + conv_desc_, + fwd_algo_, + temp_space.Data(), + temp_space.Dim() * sizeof(BaseFloat), + &ZERO, + output_desc_, + output->Data())); + if (bias != nullptr) { + CUDNN_SAFE_CALL(cudnnAddTensor(GetCudnnHandle(), + &ONE, bias_desc_, bias->Data(), &ONE, + output_desc_, output->Data())); + } + } else +#endif + { + ConvolveForward(input.Mat(), params.Mat(), &(bias->Vec()), + &(output->Mat())); + } +} + + +void ConvolutionComputation:: +ConvolveForward(const MatrixBase &input, + const MatrixBase ¶ms, + const VectorBase *bias, + MatrixBase *output) const { + const ConvolutionComputationConfig &c = config_; + // Check some dimensions. + KALDI_ASSERT( + input.NumRows() == c.num_images * c.input_image_width && + input.NumCols() == c.input_image_height * c.num_channels_in && + input.Stride() == input.NumCols() && + params.NumRows() == c.num_channels_out && + params.NumCols() == c.num_channels_in * c.filter_height * c.filter_width && + params.Stride() == params.NumCols()); + KALDI_ASSERT( + (bias != nullptr || bias->Dim() == c.num_channels_out) && + output->NumRows() == c.num_images * c.output_image_width && + output->NumCols() == c.output_image_height * c.num_channels_out && + output->Stride() == output->NumCols()); + + + if (bias != nullptr) { // Deal with the bias. + SubMatrix output_rearranged( + output->Data(), + c.num_images * c.output_image_width * c.output_image_height, + c.num_channels_out, c.num_channels_out); + output_rearranged.CopyRowsFromVec(*bias); + } + + Matrix params_rearranged(c.filter_width * c.filter_height, + c.num_channels_out * c.num_channels_in, + kUndefined, kStrideEqualNumCols); + ConvertParams(params, ¶ms_rearranged); + + // The strides in 'input' and 'output' respectively from a certain pixel of one + // image to the same pixel in another image. 
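  // For instance (illustrative numbers): with input_image_width = 5,
  // num_channels_in = 3 and input_image_height = 4, the same pixel in
  // consecutive images is 5 * 3 * 4 = 60 elements apart in the fully-packed
  // input matrix.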
+ int32 input_image_stride = + c.input_image_width * c.num_channels_in * c.input_image_height, + output_image_stride = + c.output_image_width * c.num_channels_out * c.output_image_height; + + // We're using variable names w (as in width) for horizontal positions and h + // (as in height) for vertical positions. This is perhaps not ideal. + for (int32 output_w = 0; output_w < c.output_image_width; output_w++) { + for (int32 output_h = 0; output_h < c.output_image_height; output_h++) { + for (int32 filter_h = 0; filter_h < c.filter_height; filter_h++) { + //int32 filter_h_flipped = c.filter_height - 1 - filter_h; + int32 filter_h_flipped = filter_h; // we don't flip. + int32 input_h = output_h * c.filter_stride_vertical + - c.zero_padding_vertical + + filter_h * c.filter_dilation_vertical; + if (input_h < 0 || input_h >= c.input_image_height) + continue; + for (int32 filter_w = 0; filter_w < c.filter_width; filter_w++) { + // int32 filter_w_flipped = c.filter_width - 1 - filter_w; + int32 filter_w_flipped = filter_w; // we don't flip. + int32 input_w = output_w * c.filter_stride_horizontal + - c.zero_padding_horizontal + + filter_w * c.filter_dilation_horizontal; + + if (input_w < 0 || input_w >= c.input_image_width) + continue; + + const BaseFloat *params_data = params_rearranged.RowData( + filter_w_flipped * c.filter_height + filter_h_flipped); + SubMatrix this_params(params_data, + c.num_channels_out, + c.num_channels_in, c.num_channels_in); + const BaseFloat *input_data = + input.RowData(input_w) + input_h * c.num_channels_in; + SubMatrix this_input_pixel(input_data, + c.num_images, + c.num_channels_in, + input_image_stride); + + + const BaseFloat *output_data = + output->RowData(output_w) + output_h * c.num_channels_out; + SubMatrix this_output_pixel(output_data, + c.num_images, + c.num_channels_out, + output_image_stride); + this_output_pixel.AddMatMat(1.0, this_input_pixel, kNoTrans, + this_params, kTrans, 1.0); + } + } + } + } +} + +void ConvolutionComputation:: +ConvolveBackwardData(const CuMatrixBase ¶ms, + const CuMatrixBase &output_deriv, + CuMatrixBase *input_deriv) const { +#if HAVE_CUDA == 1 + if (CuDevice::Instantiate().Enabled()) { + CuVector temp_space(temp_space_required_backward_data_ / + sizeof(BaseFloat), kUndefined); + CUDNN_SAFE_CALL(cudnnConvolutionBackwardData( + GetCudnnHandle(), + &ONE, + params_desc_, + params.Data(), + output_desc_, + output_deriv.Data(), + conv_desc_, + bwd_data_algo_, + temp_space.Data(), + temp_space.Dim() * sizeof(BaseFloat), + &ZERO, + input_desc_, + input_deriv->Data())); + } else +#endif + { + // TODO + } +} + +void ConvolutionComputation:: +ConvolveBackwardData(const MatrixBase ¶ms, + const MatrixBase &output_deriv, + MatrixBase *input_deriv) const { + // TODO +} + + + +void ConvolutionComputation:: +ConvolveBackwardParams(const CuMatrixBase &output_deriv, + const CuMatrixBase &input, + BaseFloat alpha, + CuMatrixBase *params_deriv) const { +#if HAVE_CUDA == 1 + if (CuDevice::Instantiate().Enabled()) { + CuVector temp_space(temp_space_required_backward_filter_ / + sizeof(BaseFloat), kUndefined); + CUDNN_SAFE_CALL(cudnnConvolutionBackwardFilter( + GetCudnnHandle(), + &alpha, + input_desc_, + input.Data(), + output_desc_, + output_deriv.Data(), + conv_desc_, + bwd_filter_algo_, + temp_space.Data(), + temp_space.Dim() * sizeof(BaseFloat), + &ONE, + params_desc_, + params_deriv->Data())); + } else +#endif + { + ConvolveBackwardParams(output_deriv.Mat(), input.Mat(), + alpha, &(params_deriv->Mat())); + } +} + + +void ConvolutionComputation:: 
+ConvolveBackwardParams(const MatrixBase &output_deriv, + const MatrixBase &input, + BaseFloat alpha, + MatrixBase *params_deriv) const { + // TODO +} + + +void ConvolutionComputation:: +ConvolveBackwardBias(const CuMatrixBase &output_deriv, + BaseFloat alpha, + CuVectorBase *bias_deriv) const { + if (bias_deriv == nullptr) { + return; + } +#if HAVE_CUDA == 1 + if (CuDevice::Instantiate().Enabled()) { + CUDNN_SAFE_CALL(cudnnConvolutionBackwardBias( + GetCudnnHandle(), + &alpha, + output_desc_, + output_deriv.Data(), + &ONE, + bias_desc_, + bias_deriv->Data())); + } else +#endif + { + ConvolveBackwardBias(output_deriv.Mat(), alpha, &(bias_deriv->Vec())); + } +} + +void ConvolutionComputation:: +ConvolveBackwardBias(const MatrixBase &output_deriv, + BaseFloat alpha, + VectorBase *bias_deriv) const { + if (bias_deriv == nullptr) { + return; + } + // TODO. +} + + +// This function, called only if we are not using the GPU, converts +// the params from KWHC format to WHKC format (which is more convenient +// when using the CPU. Note: K == channels-out, C == channels-in. +void ConvolutionComputation::ConvertParams( + const MatrixBase ¶ms, + MatrixBase *params_rearranged) const { + const ConvolutionComputationConfig &c = config_; + // Reinterpret params as params_reinterpret which is of dimension KC * WH (instead of K * CWH). + SubMatrix params_reinterpret(params.Data(), + c.num_channels_out * c.num_channels_in, + c.filter_width * c.filter_height, + c.filter_width * c.filter_height); + params_rearranged->CopyFromMat(params_reinterpret, kTrans); +} + + +} // namespace cudnn_convolution +} // namespace nnet3 +} // namespace kaldi diff --git a/src/nnet3/convolution-cudnn.h b/src/nnet3/convolution-cudnn.h new file mode 100644 index 00000000000..a84bc995b64 --- /dev/null +++ b/src/nnet3/convolution-cudnn.h @@ -0,0 +1,279 @@ +// nnet3/convolution-cudnn.h + +// Copyright 2018 Daniel Galvez +// 2018 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_NNET3_NNET_CUDNN_CONVOLUTION_H_ +#define KALDI_NNET3_NNET_CUDNN_CONVOLUTION_H_ + + +#include "base/kaldi-common.h" +#include "matrix/matrix-lib.h" +#include "nnet3/convolution.h" +#if HAVE_CUDA == 1 +#include +#endif + + +namespace kaldi { +namespace nnet3 { +namespace cudnn_convolution { + +/** This struct contains information about a specific convolution + computation. It combines information about the model and the data. + The examples below are very arbitrary. + */ +struct ConvolutionComputationConfig { + // The number of images we are working on, e.g. 128. + int32 num_images; + // The number of input channels, e.g. 32 + int32 num_channels_in; + // The number of output channels, e.g. 64. + int32 num_channels_out; + // The number of pixels in the filter patch in the vertical direction, e.g. 3. 
diff --git a/src/nnet3/convolution-cudnn.h b/src/nnet3/convolution-cudnn.h
new file mode 100644
index 00000000000..a84bc995b64
--- /dev/null
+++ b/src/nnet3/convolution-cudnn.h
@@ -0,0 +1,279 @@
+// nnet3/convolution-cudnn.h
+
+// Copyright      2018  Daniel Galvez
+//                2018  Johns Hopkins University (author: Daniel Povey)
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef KALDI_NNET3_NNET_CUDNN_CONVOLUTION_H_
+#define KALDI_NNET3_NNET_CUDNN_CONVOLUTION_H_
+
+
+#include "base/kaldi-common.h"
+#include "matrix/matrix-lib.h"
+#include "nnet3/convolution.h"
+#if HAVE_CUDA == 1
+#include <cudnn.h>
+#endif
+
+
+namespace kaldi {
+namespace nnet3 {
+namespace cudnn_convolution {
+
+/** This struct contains information about a specific convolution
+    computation.  It combines information about the model and the data.
+    The examples below are very arbitrary.
+ */
+struct ConvolutionComputationConfig {
+  // The number of images we are working on, e.g. 128.
+  int32 num_images;
+  // The number of input channels, e.g. 32.
+  int32 num_channels_in;
+  // The number of output channels, e.g. 64.
+  int32 num_channels_out;
+  // The number of pixels in the filter patch in the vertical direction, e.g. 3.
+  int32 filter_height;
+  // The number of pixels in the filter patch in the horizontal direction, e.g. 3.
+  int32 filter_width;
+  // The vertical stride of the filter, normally 1 but might be (e.g.) 2 if we
+  // are subsampling in the vertical (usually: frequency) dimension.
+  int32 filter_stride_vertical;
+  // The horizontal stride of the filter, normally 1 but might be (e.g.) 3 at a
+  // certain layer of the network if we are training a chain model with a
+  // frame-subsampling-factor of 3.
+  int32 filter_stride_horizontal;
+  // Normally 1; if this is more than 1, the pixels of the image patch that the
+  // filter touches will be spaced apart from each other in the vertical direction.
+  int32 filter_dilation_vertical;
+  // Normally 1; if this is more than 1, the pixels of the image patch that the
+  // filter touches will be spaced apart from each other in the horizontal direction.
+  int32 filter_dilation_horizontal;
+  // The height of the input image; e.g. this is often 40 in speech applications,
+  // subsampled to 20 or 10 later in the network.
+  int32 input_image_height;
+  // The width of the input image, which will be the same as the number of
+  // frames being computed.
+  int32 input_image_width;
+  // The amount of zero-padding in the height (normally: frequency) dimension;
+  // this number of zero rows is added at both the top and the bottom of the input.
+  // Will often be 1, if you are using 3x3 kernels and don't want to
+  // reduce the height of the image.
+  int32 zero_padding_vertical;
+  // The amount of zero-padding in the time dimension, meaning the number of
+  // zero frames that we implicitly add to the beginning and end of the utterance.
+  // This will normally be zero because in Kaldi ASR recipes we generally don't
+  // do zero padding, but duplicate the first and last frame of the input to
+  // match the amount of left and right context that the neural network requires.
+  int32 zero_padding_horizontal;
+
+
+  // The height of the output image.  The user does not have to set this;
+  // it will be computed when you call ComputeOutputImageSize().
+  int32 output_image_height;
+  // The width of the output image.  The user does not have to set this;
+  // it will be computed when you call ComputeOutputImageSize().
+  int32 output_image_width;
+
+
+  // Checks that all the configuration variables except output_image_height
+  // and output_image_width have allowed values.
+  void Check();
+
+  // Computes output_image_height and output_image_width from the other
+  // configuration values.
+  void ComputeOutputImageSize();
+
+  void Write(std::ostream &os, bool binary) const;
+
+  // Note: Read() automatically calls ComputeOutputImageSize().
+  void Read(std::istream &is, bool binary);
+
+};
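As a concrete illustration of how the struct above is meant to be filled in, here is a minimal sketch; the numbers are as arbitrary as the examples in the field comments, and ExampleConfig is a hypothetical helper name, not something defined in this header.

// Illustrative only (not part of the header): a hypothetical 3x3 convolution
// over 40-dimensional features, 100 frames per image, zero-padded in the
// frequency dimension only.
kaldi::nnet3::cudnn_convolution::ConvolutionComputationConfig ExampleConfig() {
  kaldi::nnet3::cudnn_convolution::ConvolutionComputationConfig config;
  config.num_images = 128;
  config.num_channels_in = 32;
  config.num_channels_out = 64;
  config.filter_height = 3;
  config.filter_width = 3;
  config.filter_stride_vertical = 1;
  config.filter_stride_horizontal = 1;
  config.filter_dilation_vertical = 1;
  config.filter_dilation_horizontal = 1;
  config.input_image_height = 40;   // e.g. 40-dim filterbank features.
  config.input_image_width = 100;   // number of frames.
  config.zero_padding_vertical = 1;
  config.zero_padding_horizontal = 0;
  config.Check();
  config.ComputeOutputImageSize();
  // Under the usual convolution formula,
  //   out = (in + 2 * pad - ((filter - 1) * dilation + 1)) / stride + 1,
  // we would expect output_image_height == 40 and output_image_width == 98.
  return config;
}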
+
+
+/**
+   This object allows you to execute a convolution computation, and its backprop,
+   on either a GPU (using CUDNN) or on the CPU using a compatible interface.
+
+   This object is quite lightweight: it only contains some structural data and a
+   few smallish CUDNN descriptors that are derived from it.
+
+*/
+class ConvolutionComputation final {
+public:
+  // Note: you don't have to have done ComputeOutputImageSize() on 'config';
+  // this class will do it in the constructor.
+  ConvolutionComputation(const ConvolutionComputationConfig &config);
+
+  // This constructor may be used prior to calling Read().
+  ConvolutionComputation();
+
+  const ConvolutionComputationConfig &Config() const { return config_; }
+
+  ~ConvolutionComputation();
+
+  /*
+    For an explanation of the notation below (e.g. NWHC):
+
+      N is equivalent to num_images
+      C is equivalent to num_channels_in
+      K is equivalent to num_channels_out
+      H is equivalent to input_image_height or output_image_height (for images) or
+        filter_height (for filter parameters).
+      W is equivalent to input_image_width or output_image_width (for images) or
+        filter_width (for filter parameters).
+    and the order of letters is from highest to lowest stride; e.g. in
+    NWHC, N would have the highest stride, and C a stride of 1.
+
+    Caution: for convenience, given the way nnet3 works, we flip the notion of
+    height and width that CUDNN uses, so our height is CUDNN's width, and vice
+    versa.  This is not visible to the user; we mention it just in case
+    those familiar with CUDNN get surprised at the order.
+
+     @param [in] input   NWHC fully-packed tensor, with NumRows() == N * W.
+     @param [in] params  KCWH fully-packed tensor, with NumRows() == K.
+     @param [in] bias    Pointer to a bias vector of length K (or NULL if we're
+                         not using a bias).
+     @param [out] output Pre-allocated NWHK fully-packed tensor, with NumRows() == N * W.
+  */
+  void ConvolveForward(const CuMatrixBase<BaseFloat> &input,
+                       const CuMatrixBase<BaseFloat> &params,
+                       const CuVectorBase<BaseFloat> *bias,
+                       CuMatrixBase<BaseFloat> *output) const;
+
+  /**
+   * @param [in] params        KCWH fully-packed tensor, with NumRows() == K.
+   * @param [in] output_deriv  NWHK fully-packed tensor, with NumRows() == N * W.
+   * @param [out] input_deriv  Pre-allocated NWHC fully-packed tensor, with
+   *                           NumRows() == N * W.
+   */
+  void ConvolveBackwardData(const CuMatrixBase<BaseFloat> &params,
+                            const CuMatrixBase<BaseFloat> &output_deriv,
+                            CuMatrixBase<BaseFloat> *input_deriv) const;
+
+  /**
+   * @param [in] output_deriv  NWHK fully-packed tensor, with NumRows() == N * W.
+   * @param [in] input         NWHC fully-packed tensor, with NumRows() == N * W.
+   * @param [in] alpha
+   *                 params_deriv := alpha * gradient_computed + params_deriv
+   * @param [out] params_deriv Pre-allocated KCWH fully-packed tensor,
+   *                           with NumRows() == K.
+   */
+  void ConvolveBackwardParams(const CuMatrixBase<BaseFloat> &output_deriv,
+                              const CuMatrixBase<BaseFloat> &input,
+                              BaseFloat alpha,
+                              CuMatrixBase<BaseFloat> *params_deriv) const;
+
+  /**
+   * @param [in] output_deriv  NWHK fully-packed tensor, with NumRows() == N * W.
+   * @param [in] alpha
+   *                 bias_deriv := alpha * gradient_computed + bias_deriv
+   * @param [out] bias_deriv   Pre-allocated vector of length K.
+   */
+  void ConvolveBackwardBias(const CuMatrixBase<BaseFloat> &output_deriv,
+                            BaseFloat alpha,
+                            CuVectorBase<BaseFloat> *bias_deriv) const;
+
+  // The CPU versions of the functions declared above allow the user to use the
+  // CPU even if a GPU is active.  They are also called by the versions of the
+  // same name that take CuMatrix types, in the case when either we did not
+  // compile for CUDA or we did but we are not using a GPU.
+  void ConvolveForward(const MatrixBase<BaseFloat> &input,
+                       const MatrixBase<BaseFloat> &params,
+                       const VectorBase<BaseFloat> *bias,
+                       MatrixBase<BaseFloat> *output) const;
+  void ConvolveBackwardData(const MatrixBase<BaseFloat> &params,
+                            const MatrixBase<BaseFloat> &output_deriv,
+                            MatrixBase<BaseFloat> *input_deriv) const;
+  void ConvolveBackwardParams(const MatrixBase<BaseFloat> &output_deriv,
+                              const MatrixBase<BaseFloat> &input,
+                              BaseFloat alpha,
+                              MatrixBase<BaseFloat> *params_deriv) const;
+  void ConvolveBackwardBias(const MatrixBase<BaseFloat> &output_deriv,
+                            BaseFloat alpha,
+                            VectorBase<BaseFloat> *bias_deriv) const;
+
+
+
+
+  void Write(std::ostream &os, bool binary) const { config_.Write(os, binary); }
+
+  void Read(std::istream &is, bool binary);
+
+private:
+#if HAVE_CUDA == 1
+  // Initialize the various descriptors; only called if compiled for CUDA
+  // AND we are using a GPU.
+  void InitCudnn();
+  // ComputeTempSpaceSizes() is called from InitCudnn(); it sets the
+  // temp_space_*_ member variables.
+  void ComputeTempSpaceSizes();
+  // Destroy the descriptors.
+  void DestroyCudnn();
+#endif
+
+  // Called from the constructor and Read(), this sets output_image_height_.
+  void ComputeOutputImageHeight();
+  // Called from the constructor and Read(), this sets output_image_width_.
+  void ComputeOutputImageWidth();
+
+
+  // This function, called only if we are not using the GPU, converts
+  // the params from KCWH format to WHKC format (which is more convenient
+  // when using the CPU).  params and params_rearranged must both be
+  // packed (Stride() == NumCols()), params must have num-rows equal to K
+  // (num_channels_out_), and params_rearranged must have num-rows equal
+  // to WH (filter_width_ * filter_height_).
+  void ConvertParams(const MatrixBase<BaseFloat> &params,
+                     MatrixBase<BaseFloat> *params_rearranged) const;
+  // This function does the opposite transformation of what ConvertParams()
+  // does.
+  void ConvertParamsBack(const MatrixBase<BaseFloat> &params_rearranged,
+                         MatrixBase<BaseFloat> *params) const;
+
+
+
+  ConvolutionComputationConfig config_;
+
+
+#if HAVE_CUDA == 1
+  bool descriptors_initialized_;
+  cudnnTensorDescriptor_t input_desc_;
+  cudnnTensorDescriptor_t output_desc_;
+  cudnnFilterDescriptor_t params_desc_;
+  cudnnTensorDescriptor_t bias_desc_;
+  cudnnConvolutionDescriptor_t conv_desc_;
+
+  cudnnConvolutionFwdAlgo_t fwd_algo_;
+  cudnnConvolutionBwdFilterAlgo_t bwd_filter_algo_;
+  cudnnConvolutionBwdDataAlgo_t bwd_data_algo_;
+
+  // The units of the following are all in bytes.
+  size_t temp_space_required_forward_;
+  size_t temp_space_required_backward_data_;
+  size_t temp_space_required_backward_filter_;
+#endif
+
+
+};
+
+}  // namespace cudnn_convolution
+}  // namespace nnet3
+}  // namespace kaldi
+
+#endif  // KALDI_NNET3_NNET_CUDNN_CONVOLUTION_H_
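To make the packed NWHC/KCWH layout documentation concrete, here is a minimal forward-pass usage sketch. It is illustrative only and not part of the patch: ExampleForwardPass is a made-up name, and the dimensions come from the hypothetical ExampleConfig() helper sketched earlier.

// Illustrative usage of ConvolutionComputation (not part of the patch).
#include "nnet3/convolution-cudnn.h"

void ExampleForwardPass() {
  using namespace kaldi;
  using namespace kaldi::nnet3::cudnn_convolution;

  // ExampleConfig() is the hypothetical helper sketched after the config struct;
  // it already calls Check() and ComputeOutputImageSize().
  ConvolutionComputationConfig config = ExampleConfig();
  ConvolutionComputation computation(config);

  int32 N = config.num_images, C = config.num_channels_in,
      K = config.num_channels_out;

  // input:  NWHC, NumRows() == N * input_image_width, fully packed.
  CuMatrix<BaseFloat> input(N * config.input_image_width,
                            config.input_image_height * C,
                            kSetZero, kStrideEqualNumCols);
  // params: KCWH, NumRows() == K, fully packed.
  CuMatrix<BaseFloat> params(K, C * config.filter_width * config.filter_height,
                             kSetZero, kStrideEqualNumCols);
  CuVector<BaseFloat> bias(K);
  // output: NWHK, NumRows() == N * output_image_width; pre-allocated by the caller.
  CuMatrix<BaseFloat> output(N * config.output_image_width,
                             config.output_image_height * K,
                             kSetZero, kStrideEqualNumCols);

  computation.ConvolveForward(input, params, &bias, &output);
}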
diff --git a/src/nnet3/nnet-convolutional-component.h b/src/nnet3/nnet-convolutional-component.h
index 279cec321dd..e98dd0468b4 100644
--- a/src/nnet3/nnet-convolutional-component.h
+++ b/src/nnet3/nnet-convolutional-component.h
@@ -24,6 +24,7 @@
 #include "nnet3/nnet-component-itf.h"
 #include "nnet3/natural-gradient-online.h"
 #include "nnet3/convolution.h"
+#include "nnet3/convolution-cudnn.h"
 #include <iostream>
 
 namespace kaldi {
diff --git a/src/nnet3/nnet-simple-component.h b/src/nnet3/nnet-simple-component.h
index 11c60f8f352..89ad44f50c7 100644
--- a/src/nnet3/nnet-simple-component.h
+++ b/src/nnet3/nnet-simple-component.h
@@ -1091,9 +1091,23 @@ class SumGroupComponent: public Component {
 };
 
 
-/// FixedScaleComponent applies a fixed per-element scale; it's similar
-/// to the Rescale component in the nnet1 setup (and only needed for nnet1
-/// model conversion).
+/**
+   FixedScaleComponent applies a fixed per-element scale; it's similar
+   to the Rescale component in the nnet1 setup (and only needed for nnet1
+   model conversion).
+
+   Configuration values accepted by this component:
+      scales    A filename, e.g. scales=foo/bar/scales.vec.  The file should
+                contain something readable as a Vector; the text form is like:
+                  [ 0.5 0.5 0.2 ]
+      dim       Only accepted if 'scales' is not set: the dimension of the
+                scale.  This is not very useful any more, since scales that are
+                the same for each dimension can now be captured in Descriptors,
+                e.g. Scale(2.0, some_component_node).
+      scale     If 'dim' is set, the value to which the scale should be
+                set (it will be a constant).  Otherwise it will be random (which
+                is useful only for testing purposes).
+*/
 class FixedScaleComponent: public Component {
  public:
   FixedScaleComponent() { }