From 2d76ce9ad6648bad9f71318974d871c082c98ff3 Mon Sep 17 00:00:00 2001 From: Daniel Galvez Date: Sat, 8 Sep 2018 11:06:42 -0700 Subject: [PATCH 01/22] WIP: Add CUDNN as an optional dependency to cu-device.{cc,h} Currently, this assumes that cudnn is installed in /usr/local/cuda/lib64 and /usr/local/cuda/include, because that's what my current machine has it installed. This can be cleaned up later. --- src/cudamatrix/cu-common.h | 13 +++++++++++++ src/cudamatrix/cu-device.cc | 25 +++++++++++++++++++++++-- src/cudamatrix/cu-device.h | 12 ++++++++++++ src/makefiles/cuda_32bit.mk | 8 ++++---- src/makefiles/cuda_64bit.mk | 10 ++++++---- 5 files changed, 58 insertions(+), 10 deletions(-) diff --git a/src/cudamatrix/cu-common.h b/src/cudamatrix/cu-common.h index 7446a76bf93..de1ce42ac10 100644 --- a/src/cudamatrix/cu-common.h +++ b/src/cudamatrix/cu-common.h @@ -74,6 +74,19 @@ } \ } +#if HAVE_CUDNN == 1 +#include + +#define CUDNN_SAFE_CALL(fun) \ +do { \ + cudnnStatus_t ret; \ + if ((ret = (fun)) != CUDNN_STATUS_SUCCESS) { \ + KALDI_ERR << "cudnnStatus_t " << ret << " : \"" << cudnnGetErrorString(ret) \ + << "\" returned from '" << #fun << "'"; \ + } \ +} while(0) +#endif // HAVE_CUDNN == 1 + namespace kaldi { diff --git a/src/cudamatrix/cu-device.cc b/src/cudamatrix/cu-device.cc index 6abba7fba6b..d8e26486c36 100644 --- a/src/cudamatrix/cu-device.cc +++ b/src/cudamatrix/cu-device.cc @@ -4,6 +4,7 @@ // 2013 Lucas Ondel // 2013-2015 Johns Hopkins University (author: Daniel Povey) // 2015 Guoguo Chen +// 2018 Daniel Galvez // See ../../COPYING for clarification regarding multiple authors // @@ -26,6 +27,9 @@ #include #include #include +#if HAVE_CUDNN == 1 +#include +#endif // HAVE_CUDNN == 1 #include #include @@ -92,7 +96,8 @@ void CuDevice::Initialize() { // // (2) in threads created by the user, as soon as someone calls something that // might potentially use the GPU, via CuDevice()::Instantiate(). - // If device_id_ is >= 0, this will create the cuBLAS and cuSparse handles. + // If device_id_ is >= 0, this will create the cuBLAS, cuSparse, cuDNN + // handles. KALDI_ASSERT(!initialized_); initialized_ = true; if (device_id_ == -1) { @@ -113,6 +118,11 @@ void CuDevice::Initialize() { // Initialize the cuSPARSE library CUSPARSE_SAFE_CALL(cusparseCreate(&cusparse_handle_)); CUSPARSE_SAFE_CALL(cusparseSetStream(cusparse_handle_, cudaStreamPerThread)); + +#if HAVE_CUDNN == 1 + CUDNN_SAFE_CALL(cudnnCreate(&cudnn_handle_)); + CUDNN_SAFE_CALL(cudnnSetStream(cudnn_handle_, cudaStreamPerThread)); +#endif // HAVE_CUDNN == 1 } } @@ -249,6 +259,11 @@ void CuDevice::FinalizeActiveGpu() { CUSPARSE_SAFE_CALL(cusparseCreate(&cusparse_handle_)); CUSPARSE_SAFE_CALL(cusparseSetStream(cusparse_handle_, cudaStreamPerThread)); +#if HAVE_CUDNN == 1 + CUDNN_SAFE_CALL(cudnnCreate(&cudnn_handle_)); + CUDNN_SAFE_CALL(cudnnSetStream(cudnn_handle_, cudaStreamPerThread)); +#endif // HAVE_CUDNN == 1 + // Notify the user which GPU is being userd. char name[128]; DeviceGetName(name,128, device_id); @@ -511,7 +526,8 @@ CuDevice::CuDevice(): initialized_(false), device_id_copy_(-1), cublas_handle_(NULL), - cusparse_handle_(NULL) { + cusparse_handle_(NULL), + cudnn_handle_(NULL) { } @@ -520,6 +536,11 @@ CuDevice::~CuDevice() { CUBLAS_SAFE_CALL(cublasDestroy(cublas_handle_)); if (cusparse_handle_) CUSPARSE_SAFE_CALL(cusparseDestroy(cusparse_handle_)); + if (cudnn_handle_) { +#if HAVE_CUDNN == 1 + CUDNN_SAFE_CALL(cudnnDestroy(cudnn_handle_)); +#endif // HAVE_CUDNN == 1 + } } diff --git a/src/cudamatrix/cu-device.h b/src/cudamatrix/cu-device.h index 4967ccb5045..6b7f60814fc 100644 --- a/src/cudamatrix/cu-device.h +++ b/src/cudamatrix/cu-device.h @@ -2,6 +2,7 @@ // Copyright 2009-2012 Karel Vesely // 2012-2015 Johns Hopkins University (author: Daniel Povey) +// 2018 Daniel Galvez // See ../../COPYING for clarification regarding multiple authors // @@ -35,6 +36,11 @@ #include "base/timer.h" #include "cudamatrix/cu-allocator.h" +// Forward declare the cudnnHandle_t found in cudnn.h so that we don't +// need to #include . This allows us to make cudnn an +// optional dependency. +typedef struct cudnnContext *cudnnHandle_t; + namespace kaldi { class CuTimer; @@ -80,6 +86,7 @@ class CuDevice { inline cublasHandle_t GetCublasHandle() { return cublas_handle_; } inline cusparseHandle_t GetCusparseHandle() { return cusparse_handle_; } + inline cudnnHandle_t GetCudnnHandle() { return cudnn_handle_; } // We provide functions Malloc(), MallocPitch() and Free() which replace // cudaMalloc(), cudaMallocPitch() and cudaFree(). Their function is to cache @@ -184,6 +191,7 @@ class CuDevice { /// (i.e. from outside the class), call this only if Enabled() returns true. bool IsComputeExclusive(); + // Shouldn't this be a private constructor? CuDevice(); ~CuDevice(); @@ -271,6 +279,8 @@ class CuDevice { cusparseHandle_t cusparse_handle_; + cudnnHandle_t cudnn_handle_; + }; // class CuDevice @@ -289,6 +299,8 @@ inline cublasHandle_t GetCublasHandle() { return CuDevice::Instantiate().GetCubl // A more convenient way to get the handle to use cuSPARSE APIs. inline cusparseHandle_t GetCusparseHandle() { return CuDevice::Instantiate().GetCusparseHandle(); } +inline cudnnHandle_t GetCudnnHandle() { return CuDevice::Instantiate().GetCudnnHandle(); } + } // namespace kaldi diff --git a/src/makefiles/cuda_32bit.mk b/src/makefiles/cuda_32bit.mk index f6ddfb6d80f..0e551c038db 100644 --- a/src/makefiles/cuda_32bit.mk +++ b/src/makefiles/cuda_32bit.mk @@ -6,9 +6,9 @@ $(error CUDATKDIR not defined.) endif CUDA_INCLUDE= -I$(CUDATKDIR)/include -CUDA_FLAGS = -g -Xcompiler -fPIC --verbose --machine 32 -DHAVE_CUDA \ - -ccbin $(CXX) -DKALDI_DOUBLEPRECISION=$(DOUBLE_PRECISION) \ +CUDA_FLAGS = -g -Xcompiler -fPIC --verbose --machine 32 -DHAVE_CUDA=1 \ + -DHAVE_CUDNN=1 -ccbin $(CXX) -DKALDI_DOUBLEPRECISION=$(DOUBLE_PRECISION) \ -DCUDA_API_PER_THREAD_DEFAULT_STREAM -CXXFLAGS += -DHAVE_CUDA -I$(CUDATKDIR)/include +CXXFLAGS += -DHAVE_CUDA=1 -DHAVE_CUDNN=1 -I$(CUDATKDIR)/include LDFLAGS += -L$(CUDATKDIR)/lib -Wl,-rpath=$(CUDATKDIR)/lib -LDLIBS += -lcublas -lcusparse -lcudart -lcurand #LDLIBS : The libs are loaded later than static libs in implicit rule +LDLIBS += -lcudnn -lcublas -lcusparse -lcudart -lcurand #LDLIBS : The libs are loaded later than static libs in implicit rule diff --git a/src/makefiles/cuda_64bit.mk b/src/makefiles/cuda_64bit.mk index 6a428e7391f..a749415cabb 100644 --- a/src/makefiles/cuda_64bit.mk +++ b/src/makefiles/cuda_64bit.mk @@ -5,10 +5,12 @@ ifndef CUDATKDIR $(error CUDATKDIR not defined.) endif +# TODO: Clean this up to make cudnn an optional dependency CUDA_INCLUDE= -I$(CUDATKDIR)/include -CUDA_FLAGS = -g -Xcompiler -fPIC --verbose --machine 64 -DHAVE_CUDA \ - -ccbin $(CXX) -DKALDI_DOUBLEPRECISION=$(DOUBLE_PRECISION) \ +CUDA_FLAGS = -g -Xcompiler -fPIC --verbose --machine 64 -DHAVE_CUDA=1 + -DHAVE_CUDNN=1 -ccbin $(CXX) \ + -DKALDI_DOUBLEPRECISION=$(DOUBLE_PRECISION) \ -DCUDA_API_PER_THREAD_DEFAULT_STREAM -CXXFLAGS += -DHAVE_CUDA -I$(CUDATKDIR)/include +CXXFLAGS += -DHAVE_CUDA=1 -DHAVE_CUDNN=1 -I$(CUDATKDIR)/include CUDA_LDFLAGS += -L$(CUDATKDIR)/lib64 -Wl,-rpath,$(CUDATKDIR)/lib64 -CUDA_LDLIBS += -lcublas -lcusparse -lcudart -lcurand #LDLIBS : The libs are loaded later than static libs in implicit rule +CUDA_LDLIBS += -lcudnn -lcublas -lcusparse -lcudart -lcurand #LDLIBS : The libs are loaded later than static libs in implicit rule From 0a00e7337a511cc9b0d46d9c57ab73cc7b9ceae2 Mon Sep 17 00:00:00 2001 From: Daniel Galvez Date: Sun, 9 Sep 2018 15:02:52 -0400 Subject: [PATCH 02/22] Download CUDNN, get this to build on CLSP --- src/configure | 14 ++++++++++++++ src/makefiles/cuda_64bit.mk | 22 ++++++++++++++++------ src/nnet3/Makefile | 2 +- tools/Makefile | 8 +++++++- 4 files changed, 38 insertions(+), 8 deletions(-) diff --git a/src/configure b/src/configure index 9055cd0ba02..8a417ded9f6 100755 --- a/src/configure +++ b/src/configure @@ -65,6 +65,7 @@ Configuration options: --shared Build and link against shared libraries [default=no] --use-cuda Build with CUDA [default=yes] --cudatk-dir=DIR CUDA toolkit directory + --cudnn-dir=DIR CUDNN installation directory --double-precision Build with BaseFloat set to double if yes [default=no], mostly useful for testing purposes. --static-fst Build with static OpenFst libraries [default=no] @@ -437,6 +438,7 @@ function configure_cuda { echo CUDA = true >> kaldi.mk echo CUDATKDIR = $CUDATKDIR >> kaldi.mk echo "CUDA_ARCH = $CUDA_ARCH" >> kaldi.mk + echo CUDNNDIR = $CUDNNDIR >> kaldi.mk echo >> kaldi.mk # 64bit/32bit? We do not support cross compilation with CUDA so, use direct calls to uname -m here @@ -459,6 +461,15 @@ function configure_cuda { echo "and cuda toolkit, try using --cudatk-dir=... option. Note: this is" echo "only relevant for neural net experiments" fi + + if [ ! -z $CUDNNDIR ]; then + if [ ! -f $CUDNNDIR/lib64/libcudnn.so ] | + [ ! -f $CUDNNDIR/include/cudnn.h ]; then + echo "CUDNNDIR(=$CUDNNDIR) invalid!" + fi + + + fi } function linux_configure_speex { @@ -974,6 +985,9 @@ do --cudatk-dir=*) CUDATKDIR=`read_dirname $1`; shift ;; #CUDA is used in src/cudamatrix and src/nnet{,bin} only + --cudnn-dir=*) + CUDNNDIR=`read_dirname $1`; + shift ;; --fst-version=*) OPENFST_VER=`expr "X$1" : '[^=]*=\(.*\)'`; shift;; diff --git a/src/makefiles/cuda_64bit.mk b/src/makefiles/cuda_64bit.mk index a749415cabb..e77d70c8109 100644 --- a/src/makefiles/cuda_64bit.mk +++ b/src/makefiles/cuda_64bit.mk @@ -5,12 +5,22 @@ ifndef CUDATKDIR $(error CUDATKDIR not defined.) endif -# TODO: Clean this up to make cudnn an optional dependency -CUDA_INCLUDE= -I$(CUDATKDIR)/include -CUDA_FLAGS = -g -Xcompiler -fPIC --verbose --machine 64 -DHAVE_CUDA=1 - -DHAVE_CUDNN=1 -ccbin $(CXX) \ +# Order matters here. We must tell the compiler to search +# $(CUDNNDIR)/lib64 before $(CUDATKDIR)/lib64 because the CUDNN .deb +# files install cudnn to /usr/local/cuda/lib64, which would overshadow +# the user-specified $(CUDNNDIR) +ifdef CUDNNDIR +CUDA_INCLUDE += -I$(CUDNNDIR)/include +CUDA_FLAGS += -DHAVE_CUDNN=1 +CXXFLAGS += -I$(CUDNNDIR)/include -DHAVE_CUDNN=1 +CUDA_LDFLAGS += -L$(CUDNNDIR)/lib64 -Wl,-rpath,$(CUDNNDIR)/lib64 +CUDA_LDLIBS += -lcudnn +endif +CUDA_INCLUDE += -I$(CUDATKDIR)/include +CUDA_FLAGS += -g -Xcompiler -fPIC --verbose --machine 64 -DHAVE_CUDA=1 \ + -ccbin $(CXX) \ -DKALDI_DOUBLEPRECISION=$(DOUBLE_PRECISION) \ -DCUDA_API_PER_THREAD_DEFAULT_STREAM -CXXFLAGS += -DHAVE_CUDA=1 -DHAVE_CUDNN=1 -I$(CUDATKDIR)/include +CXXFLAGS += -DHAVE_CUDA=1 -I$(CUDATKDIR)/include CUDA_LDFLAGS += -L$(CUDATKDIR)/lib64 -Wl,-rpath,$(CUDATKDIR)/lib64 -CUDA_LDLIBS += -lcudnn -lcublas -lcusparse -lcudart -lcurand #LDLIBS : The libs are loaded later than static libs in implicit rule +CUDA_LDLIBS += -lcublas -lcusparse -lcudart -lcurand #LDLIBS : The libs are loaded later than static libs in implicit rule diff --git a/src/nnet3/Makefile b/src/nnet3/Makefile index 135853cadc3..7984291360c 100644 --- a/src/nnet3/Makefile +++ b/src/nnet3/Makefile @@ -29,7 +29,7 @@ OBJFILES = nnet-common.o nnet-compile.o nnet-component-itf.o \ nnet-discriminative-diagnostics.o \ discriminative-training.o nnet-discriminative-training.o \ nnet-compile-looped.o decodable-simple-looped.o \ - decodable-online-looped.o convolution.o \ + decodable-online-looped.o convolution.o convolution-cudnn.o \ nnet-convolutional-component.o attention.o \ nnet-attention-component.o nnet-tdnn-component.o diff --git a/tools/Makefile b/tools/Makefile index 1d62e1a3765..cab8245ee2e 100644 --- a/tools/Makefile +++ b/tools/Makefile @@ -18,7 +18,7 @@ ifeq ("$(shell expr $(OPENFST_VER_NUM) \< 10600)","1") Supported versions: >= 1.6.0) endif -all: check_required_programs sph2pipe sclite openfst +all: check_required_programs sph2pipe sclite openfst cudnn @echo -e "\n\n" @echo "Warning: IRSTLM is not installed by default anymore. If you need IRSTLM" @echo "Warning: use the script extras/install_irstlm.sh" @@ -149,3 +149,9 @@ openblas_compiled: cd OpenBLAS; sed 's:# FCOMMON_OPT = -frecursive:FCOMMON_OPT = -frecursive:' < Makefile.rule >tmp && mv tmp Makefile.rule # $(MAKE) PREFIX=`pwd`/OpenBLAS/install FC=gfortran $(fortran_opt) DEBUG=1 USE_THREAD=1 NUM_THREADS=64 -C OpenBLAS all install $(MAKE) PREFIX=`pwd`/OpenBLAS/install FC=gfortran $(fortran_opt) DEBUG=1 USE_THREAD=0 -C OpenBLAS all install + +cudnn: + wget -T 10 -t 3 http://developer.download.nvidia.com/compute/redist/cudnn/v7.1.2/cudnn-9.1-linux-x64-v7.1.tgz -O cudnn-9.1-linux-x64-v7.1.tgz + -echo "c61000ed700bc5a009bc2e135bbdf736c9743212b2174a2fc9018a66cc0979ec cudnn-9.1-linux-x64-v7.1.tgz" | sha256sum -c + -mkdir -p cudnn/ + tar --no-same-owner -xzf cudnn-9.1-linux-x64-v7.1.tgz -C cudnn/ From 224d4da9d39dcad5b6c1dcbe773f9009dce776ee Mon Sep 17 00:00:00 2001 From: Daniel Galvez Date: Sun, 9 Sep 2018 15:03:08 -0400 Subject: [PATCH 03/22] Initial draft of CUDNN 2d convolution implementation. Double check that we correctly transpose height and width everywhere! --- src/nnet3/convolution-cudnn.cc | 304 +++++++++++++++++++++++++++++++++ src/nnet3/convolution-cudnn.h | 126 ++++++++++++++ 2 files changed, 430 insertions(+) create mode 100644 src/nnet3/convolution-cudnn.cc create mode 100644 src/nnet3/convolution-cudnn.h diff --git a/src/nnet3/convolution-cudnn.cc b/src/nnet3/convolution-cudnn.cc new file mode 100644 index 00000000000..c4f904e92d2 --- /dev/null +++ b/src/nnet3/convolution-cudnn.cc @@ -0,0 +1,304 @@ +// nnet3/convolution-cudnn.cc + +// Copyright 2018 Daniel Galvez + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "nnet3/convolution-cudnn.h" + +namespace kaldi { +namespace nnet3 { +namespace cudnn { + + +namespace { + const BaseFloat ONE(1.0); + const BaseFloat ZERO(0.0); +} + + ConvolutionComputation:: + ConvolutionComputation(int32 num_channels_out, int32 num_channels_in, + int32 filter_height, int32 filter_width, + int32 filter_stride_height, int32 filter_stride_width, + // dilation? + int32 filter_dilation_height, + int32 filter_dilation_width, + int32 num_images, + int32 input_image_height, int32 input_image_width, + int32 zero_padding_height, int32 zero_padding_width) { + CUDNN_SAFE_CALL(cudnnCreateTensorDescriptor(&input_desc_)); + CUDNN_SAFE_CALL(cudnnCreateTensorDescriptor(&output_desc_)); + CUDNN_SAFE_CALL(cudnnCreateFilterDescriptor(¶ms_desc_)); + CUDNN_SAFE_CALL(cudnnCreateTensorDescriptor(&bias_desc_)); + CUDNN_SAFE_CALL(cudnnCreateConvolutionDescriptor(&conv_desc_)); + CUDNN_SAFE_CALL(cudnnCreateActivationDescriptor(&activation_desc_)); + + CUDNN_SAFE_CALL( + cudnnSetTensor4dDescriptor(input_desc_, CUDNN_TENSOR_NHWC, + CUDNN_DATA_FLOAT, num_images, + num_channels_in, input_image_width, + input_image_height)); + + int32 out_kaldi_height_cudnn_width = OutputImageHeight(); + int32 out_kaldi_width_cudnn_height = OutputImageWidth(); + CUDNN_SAFE_CALL( + cudnnSetTensor4dDescriptor(output_desc_, CUDNN_TENSOR_NHWC, + CUDNN_DATA_FLOAT, num_images, + num_channels_in, out_kaldi_width_cudnn_height, + out_kaldi_height_cudnn_width)); + CUDNN_SAFE_CALL( + cudnnSetFilter4dDescriptor(params_desc_, CUDNN_DATA_FLOAT, + CUDNN_TENSOR_NHWC, num_channels_out, + num_channels_in, filter_width, filter_height)); + int32 bias_stride = 1; + CUDNN_SAFE_CALL( + cudnnSetTensorNdDescriptor(bias_desc_, CUDNN_DATA_FLOAT, 1, + &num_channels_out, &bias_stride)); + CUDNN_SAFE_CALL( + cudnnSetConvolution2dDescriptor(conv_desc_, + zero_padding_width, zero_padding_height, + filter_stride_width, filter_stride_height, + filter_dilation_width, filter_dilation_height, + CUDNN_CROSS_CORRELATION, // TODO: Double check this! + CUDNN_DATA_FLOAT)); + + const double DONT_CARE = 0; + CUDNN_SAFE_CALL( + cudnnSetActivationDescriptor(activation_desc_, CUDNN_ACTIVATION_IDENTITY, + CUDNN_PROPAGATE_NAN, DONT_CARE)); + + int32 requested_algo_count, returned_algo_count; + CUDNN_SAFE_CALL(cudnnGetConvolutionForwardAlgorithmMaxCount( + CuDevice::Instantiate().GetCudnnHandle(), &requested_algo_count)); + + cudnnConvolutionFwdAlgoPerf_t *forward_results = + new cudnnConvolutionFwdAlgoPerf_t[requested_algo_count]; + CUDNN_SAFE_CALL(cudnnFindConvolutionForwardAlgorithm( + CuDevice::Instantiate().GetCudnnHandle(), + input_desc_, + params_desc_, + conv_desc_, + output_desc_, + requested_algo_count, + &returned_algo_count, + forward_results)); + + KALDI_ASSERT(returned_algo_count > 0); + const cudnnConvolutionFwdAlgoPerf_t& best_forward = forward_results[0]; + fwd_algo_ = best_forward.algo; + delete [] forward_results; + + CUDNN_SAFE_CALL(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount( + CuDevice::Instantiate().GetCudnnHandle(), &requested_algo_count)); + cudnnConvolutionBwdFilterAlgoPerf_t *backward_filter_results = + new cudnnConvolutionBwdFilterAlgoPerf_t[requested_algo_count]; + CUDNN_SAFE_CALL(cudnnFindConvolutionBackwardFilterAlgorithm( + CuDevice::Instantiate().GetCudnnHandle(), + input_desc_, + output_desc_, + conv_desc_, + params_desc_, + requested_algo_count, + &returned_algo_count, + backward_filter_results)); + KALDI_ASSERT(returned_algo_count > 0); + const cudnnConvolutionBwdFilterAlgoPerf_t& best_backward_filter = + backward_filter_results[0]; + bwd_filter_algo_ = best_backward_filter.algo; + delete [] backward_filter_results; + + CUDNN_SAFE_CALL(cudnnGetConvolutionBackwardDataAlgorithmMaxCount( + CuDevice::Instantiate().GetCudnnHandle(), &requested_algo_count)); + cudnnConvolutionBwdDataAlgoPerf_t *backward_data_results = + new cudnnConvolutionBwdDataAlgoPerf_t[requested_algo_count]; + CUDNN_SAFE_CALL(cudnnFindConvolutionBackwardDataAlgorithm( + CuDevice::Instantiate().GetCudnnHandle(), + params_desc_, + output_desc_, + conv_desc_, + input_desc_, + requested_algo_count, + &returned_algo_count, + backward_data_results)); + KALDI_ASSERT(returned_algo_count > 0); + const cudnnConvolutionBwdDataAlgoPerf_t& best_backward_data = + backward_data_results[0]; + bwd_data_algo_ = best_backward_data.algo; + delete [] backward_data_results; + } + + ConvolutionComputation::~ConvolutionComputation() { + CUDNN_SAFE_CALL(cudnnDestroyTensorDescriptor(input_desc_)); + CUDNN_SAFE_CALL(cudnnDestroyTensorDescriptor(output_desc_)); + CUDNN_SAFE_CALL(cudnnDestroyFilterDescriptor(params_desc_)); + CUDNN_SAFE_CALL(cudnnDestroyTensorDescriptor(bias_desc_)); + CUDNN_SAFE_CALL(cudnnDestroyConvolutionDescriptor(conv_desc_)); + CUDNN_SAFE_CALL(cudnnDestroyActivationDescriptor(activation_desc_)); + } + + int32 ConvolutionComputation::OutputImageHeight() const { + int32 unused; + int32 kaldi_height_cudnn_width; + CUDNN_SAFE_CALL( + cudnnGetConvolution2dForwardOutputDim(conv_desc_, input_desc_, + params_desc_, + &unused, &unused, + &kaldi_height_cudnn_width, + &unused)); + return kaldi_height_cudnn_width; + } + + int32 ConvolutionComputation::OutputImageWidth() const { + int32 unused; + int32 kaldi_width_cudnn_height; + CUDNN_SAFE_CALL( + cudnnGetConvolution2dForwardOutputDim(conv_desc_, input_desc_, + params_desc_, + &unused, &unused, + &unused, + &kaldi_width_cudnn_height)); + return kaldi_width_cudnn_height; + } + + size_t ConvolutionComputation::TempSpaceRequiredForward() const { + size_t workspace_size_bytes; + CUDNN_SAFE_CALL(cudnnGetConvolutionForwardWorkspaceSize( + CuDevice::Instantiate().GetCudnnHandle(), + input_desc_, + params_desc_, + conv_desc_, + output_desc_, + fwd_algo_, + &workspace_size_bytes)); + return workspace_size_bytes; + } + + size_t ConvolutionComputation::TempSpaceRequiredBackwardData() const { + size_t workspace_size_bytes; + CUDNN_SAFE_CALL(cudnnGetConvolutionBackwardDataWorkspaceSize( + CuDevice::Instantiate().GetCudnnHandle(), + params_desc_, + output_desc_, + conv_desc_, + input_desc_, + bwd_data_algo_, + &workspace_size_bytes)); + return workspace_size_bytes; + } + + + size_t ConvolutionComputation::TempSpaceRequiredBackwardFilter() const { + size_t workspace_size_bytes; + CUDNN_SAFE_CALL(cudnnGetConvolutionBackwardFilterWorkspaceSize( + CuDevice::Instantiate().GetCudnnHandle(), + input_desc_, + output_desc_, + conv_desc_, + params_desc_, + bwd_filter_algo_, + &workspace_size_bytes)); + return workspace_size_bytes; + } + + + + void ConvolutionComputation:: + ConvolveForward(const CuMatrixBase &input, + const CuMatrixBase ¶ms, + const CuVectorBase &bias, + CuVectorBase *temp_space, + CuMatrixBase *output) const { + CUDNN_SAFE_CALL(cudnnConvolutionBiasActivationForward( + CuDevice::Instantiate().GetCudnnHandle(), + &ONE, + input_desc_, + input.Data(), + params_desc_, + params.Data(), + conv_desc_, + fwd_algo_, + temp_space->Data(), + temp_space->Dim() * sizeof(BaseFloat), + &ZERO, + output_desc_, + output->Data(), + bias_desc_, + bias.Data(), + activation_desc_, + output_desc_, + output->Data())); + } + + void ConvolutionComputation:: + ConvolveBackwardData(const CuMatrixBase ¶ms, + const CuMatrixBase &output_deriv, + CuVectorBase *temp, + CuMatrixBase *input_deriv) const { + CUDNN_SAFE_CALL(cudnnConvolutionBackwardData( + CuDevice::Instantiate().GetCudnnHandle(), + &ONE, + params_desc_, + params.Data(), + output_desc_, + output_deriv.Data(), + conv_desc_, + bwd_data_algo_, + temp->Data(), + temp->Dim() * sizeof(BaseFloat), + &ZERO, + input_desc_, + input_deriv->Data())); + } + + void ConvolutionComputation:: + ConvolveBackwardParams(const CuMatrixBase &output_deriv, + const CuMatrixBase &input, + BaseFloat alpha, + CuVectorBase *temp, + CuMatrixBase *params_deriv) const { + CUDNN_SAFE_CALL(cudnnConvolutionBackwardFilter( + CuDevice::Instantiate().GetCudnnHandle(), + &alpha, + input_desc_, + input.Data(), + output_desc_, + output_deriv.Data(), + conv_desc_, + bwd_filter_algo_, + temp->Data(), + temp->Dim() * sizeof(BaseFloat), + &ONE, + params_desc_, + params_deriv->Data())); + } + + void ConvolutionComputation:: + ConvolveBackwardBias(const CuMatrixBase &output_deriv, + BaseFloat alpha, + CuVectorBase *bias_deriv) const { + CUDNN_SAFE_CALL(cudnnConvolutionBackwardBias( + CuDevice::Instantiate().GetCudnnHandle(), + &alpha, + output_desc_, + output_deriv.Data(), + &ONE, + bias_desc_, + bias_deriv->Data())); + } + +} // namespace cudnn +} // namespace nnet3 +} // namespace kaldi diff --git a/src/nnet3/convolution-cudnn.h b/src/nnet3/convolution-cudnn.h new file mode 100644 index 00000000000..be4e2bbfb84 --- /dev/null +++ b/src/nnet3/convolution-cudnn.h @@ -0,0 +1,126 @@ +// nnet3/convolution-cudnn.h + +// Copyright 2018 Daniel Galvez + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_NNET3_NNET_CUDNN_CONVOLUTION_H_ +#define KALDI_NNET3_NNET_CUDNN_CONVOLUTION_H_ + + +#include "base/kaldi-common.h" +#include "matrix/matrix-lib.h" +#include "nnet3/convolution.h" + +#include + +namespace kaldi { +namespace nnet3 { +namespace cudnn { + +class ConvolutionComputation { +public: + ConvolutionComputation(int32 num_channels_out, int32 num_channels_in, + int32 filter_height, int32 filter_width, + int32 filter_stride_height, int32 filter_stride_width, + // dilation? + int32 filter_dilation_height, + int32 filter_dilation_width, + int32 num_images, + int32 input_image_height, int32 input_image_width, + int32 zero_padding_height, int32 zero_padding_width); + ~ConvolutionComputation(); + int32 OutputImageHeight() const; + int32 OutputImageWidth() const; + + /** + * Returns the size of the workspace required for each stage, in + * bytes (not 32-bit words). + */ + size_t TempSpaceRequiredForward() const; + size_t TempSpaceRequiredBackwardData() const; + size_t TempSpaceRequiredBackwardFilter() const; + + // Why aren't these const methods? That would make things a lot simpler + void ConvolveForward(const CuMatrixBase &input, + const CuMatrixBase ¶ms, + const CuVectorBase &bias, + CuVectorBase *temp_space, + CuMatrixBase *output) const; + + // Why aren't these const methods? That would make things a lot simpler + void ConvolveBackwardData(const CuMatrixBase ¶ms, + const CuMatrixBase &output_deriv, + CuVectorBase *temp, + CuMatrixBase *input_deriv) const; + + // Why aren't these const methods? That would make things a lot simpler + void ConvolveBackwardParams(const CuMatrixBase &output_deriv, + const CuMatrixBase &input, + BaseFloat alpha, + CuVectorBase *temp, + CuMatrixBase *params_deriv) const; + + // Why aren't these const methods? That would make things a lot simpler + void ConvolveBackwardBias(const CuMatrixBase &output_deriv, + BaseFloat alpha, + CuVectorBase *bias_deriv) const; + +private: + cudnnTensorDescriptor_t input_desc_; + cudnnTensorDescriptor_t output_desc_; + cudnnFilterDescriptor_t params_desc_; + cudnnTensorDescriptor_t bias_desc_; + cudnnConvolutionDescriptor_t conv_desc_; + cudnnActivationDescriptor_t activation_desc_; + + cudnnConvolutionFwdAlgo_t fwd_algo_; + cudnnConvolutionBwdFilterAlgo_t bwd_filter_algo_; + cudnnConvolutionBwdDataAlgo_t bwd_data_algo_; +}; + +/* /\** */ +/* This function does the compilation for a convolution computation; it's */ +/* a wrapper for the functions below, which should not have to be called */ +/* by the end user. */ + +/* @param [in] model The convolution model that this computation is for. */ +/* @param [in] input_indexes The list of Indexes available at the input of */ +/* the computation. */ +/* @param [in] output_indexes The list of Indexes requested to be computed */ +/* at the output of the computation. It is an error if */ +/* all dependencies are not satisfied (specifically: for */ +/* each Index (n,t,x) in 'output_indexes', the Index */ +/* (n,t+time_offset,x) must be present in 'input_indexes' */ +/* for each time_offset in model.required_time_offsets. */ +/* @param [out] computation If non-NULL, the compiled computation will be */ +/* written to this location. */ + +/* *\/ */ +/* void CompileConvolutionComputation( */ +/* const ConvolutionModel& model, */ +/* const std::vector &input_indexes, */ +/* const std::vector &output_indexes, */ +/* const ConvolutionComputationOptions &opts, */ +/* cudnn::ConvolutionComputation *computation, */ +/* std::vector *input_indexes_modified, */ +/* std::vector *output_indexes_modified); */ + +} // namespace cudnn +} // namespace nnet3 +} // namespace kaldi + +#endif // KALDI_NNET3_NNET_CUDNN_CONVOLUTION_H_ From c9806e06704c6c2c66b02cbe7e8283124d497363 Mon Sep 17 00:00:00 2001 From: Daniel Galvez Date: Sun, 9 Sep 2018 16:49:11 -0400 Subject: [PATCH 04/22] Minor --- src/nnet3/convolution-cudnn.h | 1 + src/nnet3/nnet-convolutional-component.h | 1 + 2 files changed, 2 insertions(+) diff --git a/src/nnet3/convolution-cudnn.h b/src/nnet3/convolution-cudnn.h index be4e2bbfb84..44e6b1f441f 100644 --- a/src/nnet3/convolution-cudnn.h +++ b/src/nnet3/convolution-cudnn.h @@ -80,6 +80,7 @@ class ConvolutionComputation { CuVectorBase *bias_deriv) const; private: + // Need to use the PIMPL idiom so our clients don't need to have cudnn.h cudnnTensorDescriptor_t input_desc_; cudnnTensorDescriptor_t output_desc_; cudnnFilterDescriptor_t params_desc_; diff --git a/src/nnet3/nnet-convolutional-component.h b/src/nnet3/nnet-convolutional-component.h index 279cec321dd..e01ab3a721c 100644 --- a/src/nnet3/nnet-convolutional-component.h +++ b/src/nnet3/nnet-convolutional-component.h @@ -24,6 +24,7 @@ #include "nnet3/nnet-component-itf.h" #include "nnet3/natural-gradient-online.h" #include "nnet3/convolution.h" +#include "nnet3/cudnn-convolution.h" #include namespace kaldi { From 5b1855a6c11619922131ceb77977534f342d0ed0 Mon Sep 17 00:00:00 2001 From: Daniel Galvez Date: Sat, 15 Sep 2018 16:21:33 -0400 Subject: [PATCH 05/22] [egs] mini librispeech fix for CLSP. b17 disk is no longer writable. --- egs/mini_librispeech/s5/run.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/mini_librispeech/s5/run.sh b/egs/mini_librispeech/s5/run.sh index 681859edf8a..3ab9d243ef6 100755 --- a/egs/mini_librispeech/s5/run.sh +++ b/egs/mini_librispeech/s5/run.sh @@ -48,7 +48,7 @@ if [ $stage -le 2 ]; then # spread the mfccs over various machines, as this data-set is quite large. if [[ $(hostname -f) == *.clsp.jhu.edu ]]; then mfcc=$(basename mfccdir) # in case was absolute pathname (unlikely), get basename. - utils/create_split_dir.pl /export/b{07,14,16,17}/$USER/kaldi-data/egs/librispeech/s5/$mfcc/storage \ + utils/create_split_dir.pl /export/b{07,14,16,18}/$USER/kaldi-data/egs/librispeech/s5/$mfcc/storage \ $mfccdir/storage fi From 8e00e4f77a94282235efc892fd7146ff440eb1f3 Mon Sep 17 00:00:00 2001 From: Daniel Galvez Date: Sat, 15 Sep 2018 16:43:32 -0400 Subject: [PATCH 06/22] [src] Fix Singleton implementation of CuDevice. No longer publicly expose the default constructor. That could allow someone to do something like this: CuDevice device1; CuDevice device2; We're not sure what this would cause, but it probably wouldn't be good. --- src/cudamatrix/cu-device.cc | 1 - src/cudamatrix/cu-device.h | 5 ++--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/cudamatrix/cu-device.cc b/src/cudamatrix/cu-device.cc index d8e26486c36..bb3f31be72b 100644 --- a/src/cudamatrix/cu-device.cc +++ b/src/cudamatrix/cu-device.cc @@ -530,7 +530,6 @@ CuDevice::CuDevice(): cudnn_handle_(NULL) { } - CuDevice::~CuDevice() { if (cublas_handle_) CUBLAS_SAFE_CALL(cublasDestroy(cublas_handle_)); diff --git a/src/cudamatrix/cu-device.h b/src/cudamatrix/cu-device.h index 6b7f60814fc..95c0d48413b 100644 --- a/src/cudamatrix/cu-device.h +++ b/src/cudamatrix/cu-device.h @@ -191,11 +191,10 @@ class CuDevice { /// (i.e. from outside the class), call this only if Enabled() returns true. bool IsComputeExclusive(); - // Shouldn't this be a private constructor? - CuDevice(); - ~CuDevice(); private: + // Default constructor used to initialize this_thread_device_ + CuDevice(); CuDevice(CuDevice&); // Disallow. CuDevice &operator=(CuDevice&); // Disallow. From 5e4d2705ed9bf4a1448dcd7dc99fdfc171d3220c Mon Sep 17 00:00:00 2001 From: Daniel Galvez Date: Sat, 15 Sep 2018 16:46:10 -0400 Subject: [PATCH 07/22] Small CUDNN fixes --- src/cudamatrix/cu-device.cc | 4 ++-- src/nnet3/nnet-convolutional-component.h | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/cudamatrix/cu-device.cc b/src/cudamatrix/cu-device.cc index bb3f31be72b..45c74e394e8 100644 --- a/src/cudamatrix/cu-device.cc +++ b/src/cudamatrix/cu-device.cc @@ -535,11 +535,11 @@ CuDevice::~CuDevice() { CUBLAS_SAFE_CALL(cublasDestroy(cublas_handle_)); if (cusparse_handle_) CUSPARSE_SAFE_CALL(cusparseDestroy(cusparse_handle_)); - if (cudnn_handle_) { #if HAVE_CUDNN == 1 + if (cudnn_handle_) { CUDNN_SAFE_CALL(cudnnDestroy(cudnn_handle_)); -#endif // HAVE_CUDNN == 1 } +#endif // HAVE_CUDNN == 1 } diff --git a/src/nnet3/nnet-convolutional-component.h b/src/nnet3/nnet-convolutional-component.h index e01ab3a721c..e98dd0468b4 100644 --- a/src/nnet3/nnet-convolutional-component.h +++ b/src/nnet3/nnet-convolutional-component.h @@ -24,7 +24,7 @@ #include "nnet3/nnet-component-itf.h" #include "nnet3/natural-gradient-online.h" #include "nnet3/convolution.h" -#include "nnet3/cudnn-convolution.h" +#include "nnet3/convolution-cudnn.h" #include namespace kaldi { From 0387a0cf229a094263320debb5e585d8eff6f9d4 Mon Sep 17 00:00:00 2001 From: Daniel Galvez Date: Sun, 30 Sep 2018 18:38:49 -0400 Subject: [PATCH 08/22] Fix implementation's height-width switching. Document data formats expected of each member function. --- src/nnet3/convolution-cudnn.cc | 54 +++++++------- src/nnet3/convolution-cudnn.h | 131 +++++++++++++++++++++++---------- 2 files changed, 123 insertions(+), 62 deletions(-) diff --git a/src/nnet3/convolution-cudnn.cc b/src/nnet3/convolution-cudnn.cc index c4f904e92d2..723540a1c66 100644 --- a/src/nnet3/convolution-cudnn.cc +++ b/src/nnet3/convolution-cudnn.cc @@ -33,7 +33,6 @@ namespace { ConvolutionComputation(int32 num_channels_out, int32 num_channels_in, int32 filter_height, int32 filter_width, int32 filter_stride_height, int32 filter_stride_width, - // dilation? int32 filter_dilation_height, int32 filter_dilation_width, int32 num_images, @@ -47,11 +46,24 @@ namespace { CUDNN_SAFE_CALL(cudnnCreateActivationDescriptor(&activation_desc_)); CUDNN_SAFE_CALL( - cudnnSetTensor4dDescriptor(input_desc_, CUDNN_TENSOR_NHWC, - CUDNN_DATA_FLOAT, num_images, - num_channels_in, input_image_width, - input_image_height)); + cudnnSetTensor4dDescriptor(input_desc_, CUDNN_TENSOR_NHWC, + CUDNN_DATA_FLOAT, num_images, + num_channels_in, input_image_width, + input_image_height)); + CUDNN_SAFE_CALL( + cudnnSetConvolution2dDescriptor(conv_desc_, + zero_padding_width, zero_padding_height, + filter_stride_width, filter_stride_height, + filter_dilation_width, filter_dilation_height, + CUDNN_CROSS_CORRELATION, // TODO: Double check this! + CUDNN_DATA_FLOAT)); + CUDNN_SAFE_CALL( + cudnnSetFilter4dDescriptor(params_desc_, CUDNN_DATA_FLOAT, + CUDNN_TENSOR_NCHW, num_channels_out, + num_channels_in, filter_width, filter_height)); + // These two member functions depend only on input_desc_, + // conv_desc_, and params_desc_, so they are safe to call now. int32 out_kaldi_height_cudnn_width = OutputImageHeight(); int32 out_kaldi_width_cudnn_height = OutputImageWidth(); CUDNN_SAFE_CALL( @@ -59,21 +71,10 @@ namespace { CUDNN_DATA_FLOAT, num_images, num_channels_in, out_kaldi_width_cudnn_height, out_kaldi_height_cudnn_width)); - CUDNN_SAFE_CALL( - cudnnSetFilter4dDescriptor(params_desc_, CUDNN_DATA_FLOAT, - CUDNN_TENSOR_NHWC, num_channels_out, - num_channels_in, filter_width, filter_height)); - int32 bias_stride = 1; + const int32 bias_stride[] = {1}; CUDNN_SAFE_CALL( cudnnSetTensorNdDescriptor(bias_desc_, CUDNN_DATA_FLOAT, 1, - &num_channels_out, &bias_stride)); - CUDNN_SAFE_CALL( - cudnnSetConvolution2dDescriptor(conv_desc_, - zero_padding_width, zero_padding_height, - filter_stride_width, filter_stride_height, - filter_dilation_width, filter_dilation_height, - CUDNN_CROSS_CORRELATION, // TODO: Double check this! - CUDNN_DATA_FLOAT)); + &num_channels_out, bias_stride)); const double DONT_CARE = 0; CUDNN_SAFE_CALL( @@ -96,7 +97,8 @@ namespace { &returned_algo_count, forward_results)); - KALDI_ASSERT(returned_algo_count > 0); + KALDI_ASSERT(returned_algo_count > 0 && + "No algorithms were returned by CUDNN."); const cudnnConvolutionFwdAlgoPerf_t& best_forward = forward_results[0]; fwd_algo_ = best_forward.algo; delete [] forward_results; @@ -114,7 +116,8 @@ namespace { requested_algo_count, &returned_algo_count, backward_filter_results)); - KALDI_ASSERT(returned_algo_count > 0); + KALDI_ASSERT(returned_algo_count > 0 && + "No algorithms were returned by CUDNN."); const cudnnConvolutionBwdFilterAlgoPerf_t& best_backward_filter = backward_filter_results[0]; bwd_filter_algo_ = best_backward_filter.algo; @@ -133,7 +136,8 @@ namespace { requested_algo_count, &returned_algo_count, backward_data_results)); - KALDI_ASSERT(returned_algo_count > 0); + KALDI_ASSERT(returned_algo_count > 0 && + "No algorithms were returned by CUDNN."); const cudnnConvolutionBwdDataAlgoPerf_t& best_backward_data = backward_data_results[0]; bwd_data_algo_ = best_backward_data.algo; @@ -156,8 +160,8 @@ namespace { cudnnGetConvolution2dForwardOutputDim(conv_desc_, input_desc_, params_desc_, &unused, &unused, - &kaldi_height_cudnn_width, - &unused)); + &unused, + &kaldi_height_cudnn_width)); return kaldi_height_cudnn_width; } @@ -168,8 +172,8 @@ namespace { cudnnGetConvolution2dForwardOutputDim(conv_desc_, input_desc_, params_desc_, &unused, &unused, - &unused, - &kaldi_width_cudnn_height)); + &kaldi_width_cudnn_height, + &unused)); return kaldi_width_cudnn_height; } diff --git a/src/nnet3/convolution-cudnn.h b/src/nnet3/convolution-cudnn.h index 44e6b1f441f..e6c37667f55 100644 --- a/src/nnet3/convolution-cudnn.h +++ b/src/nnet3/convolution-cudnn.h @@ -25,18 +25,76 @@ #include "matrix/matrix-lib.h" #include "nnet3/convolution.h" +// TODO: Consider forward declaring types like +// cudnnTensorDescriptor_t, so that this header file doesn't depend on +// cudnn.h #include namespace kaldi { namespace nnet3 { namespace cudnn { -class ConvolutionComputation { +class ConvolutionComputation final { public: + // Represents structural information about a convolution computation, + // with filters, padding, striding, inputs and outputs of a specified size. The same interface + // is usable on both GPU and CPU. You create this object only after you know the + // number of images and input and output sizes, and it will be stored as part of + // a NnetComputation (i.e. a compiled computation) and re-used between different + // minibatches. This object is lightweight. + // + // In the following docstrings, consider: + // N to be equivalent to num_images + // C to be equivalent to num_channels_in + // K to be equivalent to num_channels_out + // H to be equivalent to input_image_height, or filter_height, + // depending on context + // W to be equivalent to input_image_width, or filter_width, + // depending on context + // + // @param [in] num_channels_out Number of output channels, e.g. 64. + // @param [in] num_channels_in Number of input channels, e.g. 32. + // @param [in] filter_height Height of filter patch, e.g. 3 (for 3x3 kernel). Corresponds + // to the 'frequency' dimension in normal speech applications, or + // height in OCR applications. + // @param [in] filter_width Width of filter patch, e.g. 3 (for 3x3 kernel). Corresponds + // to the 'time' dimension in normal speech applications. + // @param [in] filter_stride_height Filter stride in the height ('frequency') dimension. + // Will normally be 1 in speech and OCR applications. + // @param [in] filter_stride_width Filter stride in the width ('time') dimension. + // Will usually be 1 in most layers, but may be 2 or 3 if + // we are doing subsampling on this layer (e.g. in + // reduced-frame-rate models like chain models). + // @param [in] filter_dilation_height Filter dilation in the height ('frequency') + // dimension. Equals the stride, in the input image, of + // individual elements of the filter patch. Will + // normally be 1. + // @param [in] filter_dilation_width Filter dilation in the width ('time') + // dimension. Will normally be 1, but could + // be more than one if, for instance, you have components + // with time-stride > 1 which for some reason are required + // to be evaluated on every frame. + // @param [in] num_images The number of images we are processing, generally + // equal to the minibatch size. + // @param [in] input_image_height The height of the input images. Corresponds to + // the number of frequency bins, in speech applications. + // @param [in] input_image_width The width of the input images. Corresponds to + // the number of time frames on the input, in speech + // applications. + // @param [in] zero_padding_height The number of pixels that we zero-pad with on + // the bottom, and on the top, of the image (the + // frequency dimension, in speech applications). Would + // be 1, for instance, if you are using a 3x3 kernel + // and don't want to lose frequency bins. + // @param [in] zero_padding_width The number of frames that we zero-pad with on + // the left, and on the right, of the image (time + // dimension). Likely to be 0 in many speech applications, + // since we normally deal with edge effects by padding + // with repeats of the first and last frame; but + // padding is supported by the component. ConvolutionComputation(int32 num_channels_out, int32 num_channels_in, int32 filter_height, int32 filter_width, int32 filter_stride_height, int32 filter_stride_width, - // dilation? int32 filter_dilation_height, int32 filter_dilation_width, int32 num_images, @@ -48,39 +106,65 @@ class ConvolutionComputation { /** * Returns the size of the workspace required for each stage, in - * bytes (not 32-bit words). + * bytes (_not_ 32-bit words). */ size_t TempSpaceRequiredForward() const; size_t TempSpaceRequiredBackwardData() const; size_t TempSpaceRequiredBackwardFilter() const; - // Why aren't these const methods? That would make things a lot simpler + /** + * @param [in] input NWHC fully-packed tensor, with N == NumRows() + * @param [in] params KCWH fully-packed tensor, with K == NumRows() + * @param [in] bias vector of length K + * @param [in/out] temp_space Pointer to pre-allocated memory of size at least + * this->TempSpaceRequiredForward() bytes + * @param [out] output Pre-allocated NWHK fully-packed tensor, with N == NumRows() + */ void ConvolveForward(const CuMatrixBase &input, const CuMatrixBase ¶ms, const CuVectorBase &bias, CuVectorBase *temp_space, CuMatrixBase *output) const; - // Why aren't these const methods? That would make things a lot simpler + /** + * @param [in] params KCWH fully-packed tensor, with K == NumRows() + * @param [in] output_deriv NWHK fully-packed tensor, with N == NumRows() + * @param [in/out] temp_space Pointer to pre-allocated memory of size at least + * this->TempSpaceRequiredBackwardData() bytes + * @param [out] input_deriv Pre-allocated NWHC fully-packed tensor, with N == NumRows() + */ void ConvolveBackwardData(const CuMatrixBase ¶ms, const CuMatrixBase &output_deriv, - CuVectorBase *temp, + CuVectorBase *temp_space, CuMatrixBase *input_deriv) const; - // Why aren't these const methods? That would make things a lot simpler + /** + * @param [in] output_deriv NWHK fully-packed tensor, with N == NumRows() + * @param [in] input NWHC fully-packed tensor, with N == NumRows() + * @param [in] alpha + * params_deriv := alpha * gradient_computed + params_deriv + * @param [in] params KCWH fully-packed tensor, with K == NumRows() + * @param [in/out] temp_space Pointer to pre-allocated memory of size at least + * this->TempSpaceRequiredBackwardFilter() bytes + * @param [out] params_deriv Pre-allocated KCWH fully-packed tensor, with K == NumRows() + */ void ConvolveBackwardParams(const CuMatrixBase &output_deriv, const CuMatrixBase &input, BaseFloat alpha, - CuVectorBase *temp, + CuVectorBase *temp_space, CuMatrixBase *params_deriv) const; - // Why aren't these const methods? That would make things a lot simpler + /** + * @param [in] output_deriv NWHK fully-packed tensor, with N == NumRows() + * @param [in] alpha + * bias_deriv := alpha * gradient_computed + bias_deriv + * @param [out] bias_deriv Pre-allocated vector of length K + */ void ConvolveBackwardBias(const CuMatrixBase &output_deriv, BaseFloat alpha, CuVectorBase *bias_deriv) const; private: - // Need to use the PIMPL idiom so our clients don't need to have cudnn.h cudnnTensorDescriptor_t input_desc_; cudnnTensorDescriptor_t output_desc_; cudnnFilterDescriptor_t params_desc_; @@ -93,33 +177,6 @@ class ConvolutionComputation { cudnnConvolutionBwdDataAlgo_t bwd_data_algo_; }; -/* /\** */ -/* This function does the compilation for a convolution computation; it's */ -/* a wrapper for the functions below, which should not have to be called */ -/* by the end user. */ - -/* @param [in] model The convolution model that this computation is for. */ -/* @param [in] input_indexes The list of Indexes available at the input of */ -/* the computation. */ -/* @param [in] output_indexes The list of Indexes requested to be computed */ -/* at the output of the computation. It is an error if */ -/* all dependencies are not satisfied (specifically: for */ -/* each Index (n,t,x) in 'output_indexes', the Index */ -/* (n,t+time_offset,x) must be present in 'input_indexes' */ -/* for each time_offset in model.required_time_offsets. */ -/* @param [out] computation If non-NULL, the compiled computation will be */ -/* written to this location. */ - -/* *\/ */ -/* void CompileConvolutionComputation( */ -/* const ConvolutionModel& model, */ -/* const std::vector &input_indexes, */ -/* const std::vector &output_indexes, */ -/* const ConvolutionComputationOptions &opts, */ -/* cudnn::ConvolutionComputation *computation, */ -/* std::vector *input_indexes_modified, */ -/* std::vector *output_indexes_modified); */ - } // namespace cudnn } // namespace nnet3 } // namespace kaldi From be6fc2ac32e240c475da3ac542b1040f5cb0c6cd Mon Sep 17 00:00:00 2001 From: Daniel Galvez Date: Mon, 8 Oct 2018 02:52:32 -0400 Subject: [PATCH 09/22] Make CUDNN mandatory if building with CUDA. There may be edge cases in the configure script with this remaining. We still don't look up the version of CUDNN to download for your CUDA version. Also make sure that non-CUDA builds still compile. This is achieved by the cu-cudnn-helper.h file, which forward declares cudnn types, which fortunately are all pointer types (except for enums, which don't matter in this case). --- src/cudamatrix/cu-common.h | 5 +--- src/cudamatrix/cu-cudnn-helper.h | 42 ++++++++++++++++++++++++++++++++ src/cudamatrix/cu-device.cc | 12 +-------- src/cudamatrix/cu-device.h | 6 +---- src/makefiles/cuda_32bit.mk | 8 ++++-- src/makefiles/cuda_64bit.mk | 9 ++++--- src/nnet3/Makefile | 6 ++++- src/nnet3/convolution-cudnn.h | 5 +--- 8 files changed, 62 insertions(+), 31 deletions(-) create mode 100644 src/cudamatrix/cu-cudnn-helper.h diff --git a/src/cudamatrix/cu-common.h b/src/cudamatrix/cu-common.h index de1ce42ac10..621fb07f6b9 100644 --- a/src/cudamatrix/cu-common.h +++ b/src/cudamatrix/cu-common.h @@ -23,6 +23,7 @@ #ifndef KALDI_CUDAMATRIX_CU_COMMON_H_ #define KALDI_CUDAMATRIX_CU_COMMON_H_ #include "cudamatrix/cu-matrixdim.h" // for CU1DBLOCK and CU2DBLOCK +#include "cudamatrix/cu-cudnn-helper.h" #include #include @@ -74,9 +75,6 @@ } \ } -#if HAVE_CUDNN == 1 -#include - #define CUDNN_SAFE_CALL(fun) \ do { \ cudnnStatus_t ret; \ @@ -85,7 +83,6 @@ do { << "\" returned from '" << #fun << "'"; \ } \ } while(0) -#endif // HAVE_CUDNN == 1 namespace kaldi { diff --git a/src/cudamatrix/cu-cudnn-helper.h b/src/cudamatrix/cu-cudnn-helper.h new file mode 100644 index 00000000000..9d4ab84aa50 --- /dev/null +++ b/src/cudamatrix/cu-cudnn-helper.h @@ -0,0 +1,42 @@ +// cudamatrix/cu-cudnn-helper.h + +// Copyright 2018 Daniel Galvez + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_CUDAMATRIX_CU_CUDNN_HELPER_H_ +#define KALDI_CUDAMATRIX_CU_CUDNN_HELPER_H_ + +#if HAVE_CUDA == 1 +#include +#else +typedef struct cudnnTensorStruct* cudnnTensorDescriptor_t; +typedef struct cudnnConvolutionStruct* cudnnConvolutionDescriptor_t; +typedef struct cudnnPoolingStruct* cudnnPoolingDescriptor_t; +typedef struct cudnnFilterStruct* cudnnFilterDescriptor_t; +typedef struct cudnnLRNStruct* cudnnLRNDescriptor_t; +typedef struct cudnnActivationStruct* cudnnActivationDescriptor_t; +typedef struct cudnnSpatialTransformerStruct* cudnnSpatialTransformerDescriptor_t; +typedef struct cudnnOpTensorStruct* cudnnOpTensorDescriptor_t; +typedef struct cudnnReduceTensorStruct* cudnnReduceTensorDescriptor_t; +typedef struct cudnnCTCLossStruct* cudnnCTCLossDescriptor_t; + +typedef enum {} cudnnConvolutionBwdDataAlgo_t; +typedef enum {} cudnnConvolutionBwdFilterAlgo_t; +typedef enum {} cudnnConvolutionFwdAlgo_t; +#endif // HAVE_CUDA == 1 + +#endif // KALDI_CUDAMATRIX_CU_CUDNN_HELPER_H_ diff --git a/src/cudamatrix/cu-device.cc b/src/cudamatrix/cu-device.cc index 45c74e394e8..eea61d130ae 100644 --- a/src/cudamatrix/cu-device.cc +++ b/src/cudamatrix/cu-device.cc @@ -27,9 +27,6 @@ #include #include #include -#if HAVE_CUDNN == 1 -#include -#endif // HAVE_CUDNN == 1 #include #include @@ -119,10 +116,8 @@ void CuDevice::Initialize() { CUSPARSE_SAFE_CALL(cusparseCreate(&cusparse_handle_)); CUSPARSE_SAFE_CALL(cusparseSetStream(cusparse_handle_, cudaStreamPerThread)); -#if HAVE_CUDNN == 1 CUDNN_SAFE_CALL(cudnnCreate(&cudnn_handle_)); CUDNN_SAFE_CALL(cudnnSetStream(cudnn_handle_, cudaStreamPerThread)); -#endif // HAVE_CUDNN == 1 } } @@ -259,10 +254,8 @@ void CuDevice::FinalizeActiveGpu() { CUSPARSE_SAFE_CALL(cusparseCreate(&cusparse_handle_)); CUSPARSE_SAFE_CALL(cusparseSetStream(cusparse_handle_, cudaStreamPerThread)); -#if HAVE_CUDNN == 1 CUDNN_SAFE_CALL(cudnnCreate(&cudnn_handle_)); CUDNN_SAFE_CALL(cudnnSetStream(cudnn_handle_, cudaStreamPerThread)); -#endif // HAVE_CUDNN == 1 // Notify the user which GPU is being userd. char name[128]; @@ -535,11 +528,8 @@ CuDevice::~CuDevice() { CUBLAS_SAFE_CALL(cublasDestroy(cublas_handle_)); if (cusparse_handle_) CUSPARSE_SAFE_CALL(cusparseDestroy(cusparse_handle_)); -#if HAVE_CUDNN == 1 - if (cudnn_handle_) { + if (cudnn_handle_) CUDNN_SAFE_CALL(cudnnDestroy(cudnn_handle_)); - } -#endif // HAVE_CUDNN == 1 } diff --git a/src/cudamatrix/cu-device.h b/src/cudamatrix/cu-device.h index 95c0d48413b..2ff641d2d1c 100644 --- a/src/cudamatrix/cu-device.h +++ b/src/cudamatrix/cu-device.h @@ -35,11 +35,7 @@ #include "base/kaldi-common.h" #include "base/timer.h" #include "cudamatrix/cu-allocator.h" - -// Forward declare the cudnnHandle_t found in cudnn.h so that we don't -// need to #include . This allows us to make cudnn an -// optional dependency. -typedef struct cudnnContext *cudnnHandle_t; +#include "cudamatrix/cu-cudnn-helper.h" namespace kaldi { diff --git a/src/makefiles/cuda_32bit.mk b/src/makefiles/cuda_32bit.mk index 0e551c038db..c8908ef45af 100644 --- a/src/makefiles/cuda_32bit.mk +++ b/src/makefiles/cuda_32bit.mk @@ -4,11 +4,15 @@ endif ifndef CUDATKDIR $(error CUDATKDIR not defined.) endif +ifndef CUDNNDIR +$(error CUDNNDIR not defined.) +endif + CUDA_INCLUDE= -I$(CUDATKDIR)/include CUDA_FLAGS = -g -Xcompiler -fPIC --verbose --machine 32 -DHAVE_CUDA=1 \ - -DHAVE_CUDNN=1 -ccbin $(CXX) -DKALDI_DOUBLEPRECISION=$(DOUBLE_PRECISION) \ + -ccbin $(CXX) -DKALDI_DOUBLEPRECISION=$(DOUBLE_PRECISION) \ -DCUDA_API_PER_THREAD_DEFAULT_STREAM -CXXFLAGS += -DHAVE_CUDA=1 -DHAVE_CUDNN=1 -I$(CUDATKDIR)/include +CXXFLAGS += -DHAVE_CUDA=1 -I$(CUDATKDIR)/include LDFLAGS += -L$(CUDATKDIR)/lib -Wl,-rpath=$(CUDATKDIR)/lib LDLIBS += -lcudnn -lcublas -lcusparse -lcudart -lcurand #LDLIBS : The libs are loaded later than static libs in implicit rule diff --git a/src/makefiles/cuda_64bit.mk b/src/makefiles/cuda_64bit.mk index e77d70c8109..3fa75cedb75 100644 --- a/src/makefiles/cuda_64bit.mk +++ b/src/makefiles/cuda_64bit.mk @@ -4,18 +4,19 @@ endif ifndef CUDATKDIR $(error CUDATKDIR not defined.) endif +ifndef CUDNNDIR +$(error CUDNNDIR not defined.) +endif # Order matters here. We must tell the compiler to search # $(CUDNNDIR)/lib64 before $(CUDATKDIR)/lib64 because the CUDNN .deb # files install cudnn to /usr/local/cuda/lib64, which would overshadow # the user-specified $(CUDNNDIR) -ifdef CUDNNDIR CUDA_INCLUDE += -I$(CUDNNDIR)/include -CUDA_FLAGS += -DHAVE_CUDNN=1 -CXXFLAGS += -I$(CUDNNDIR)/include -DHAVE_CUDNN=1 +CXXFLAGS += -I$(CUDNNDIR)/include CUDA_LDFLAGS += -L$(CUDNNDIR)/lib64 -Wl,-rpath,$(CUDNNDIR)/lib64 CUDA_LDLIBS += -lcudnn -endif + CUDA_INCLUDE += -I$(CUDATKDIR)/include CUDA_FLAGS += -g -Xcompiler -fPIC --verbose --machine 64 -DHAVE_CUDA=1 \ -ccbin $(CXX) \ diff --git a/src/nnet3/Makefile b/src/nnet3/Makefile index 7984291360c..b94dee5421d 100644 --- a/src/nnet3/Makefile +++ b/src/nnet3/Makefile @@ -29,11 +29,15 @@ OBJFILES = nnet-common.o nnet-compile.o nnet-component-itf.o \ nnet-discriminative-diagnostics.o \ discriminative-training.o nnet-discriminative-training.o \ nnet-compile-looped.o decodable-simple-looped.o \ - decodable-online-looped.o convolution.o convolution-cudnn.o \ + decodable-online-looped.o convolution.o \ nnet-convolutional-component.o attention.o \ nnet-attention-component.o nnet-tdnn-component.o +ifeq ($(CUDA), true) +OBJFILES += convolution-cudnn.o +endif + LIBNAME = kaldi-nnet3 ADDLIBS = ../chain/kaldi-chain.a ../cudamatrix/kaldi-cudamatrix.a \ diff --git a/src/nnet3/convolution-cudnn.h b/src/nnet3/convolution-cudnn.h index e6c37667f55..bfd5403b5d9 100644 --- a/src/nnet3/convolution-cudnn.h +++ b/src/nnet3/convolution-cudnn.h @@ -25,10 +25,7 @@ #include "matrix/matrix-lib.h" #include "nnet3/convolution.h" -// TODO: Consider forward declaring types like -// cudnnTensorDescriptor_t, so that this header file doesn't depend on -// cudnn.h -#include +#include namespace kaldi { namespace nnet3 { From 5bb72e567a9cc1238594c401477fbc0b933baad0 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 12 Oct 2018 23:07:34 -0400 Subject: [PATCH 10/22] [src] Updates to docs in convolution branch --- src/nnet3/convolution-cudnn.cc | 278 +++++++++++++++++---------------- src/nnet3/convolution-cudnn.h | 135 ++++++++-------- 2 files changed, 214 insertions(+), 199 deletions(-) diff --git a/src/nnet3/convolution-cudnn.cc b/src/nnet3/convolution-cudnn.cc index 723540a1c66..f39db408129 100644 --- a/src/nnet3/convolution-cudnn.cc +++ b/src/nnet3/convolution-cudnn.cc @@ -25,69 +25,71 @@ namespace cudnn { namespace { - const BaseFloat ONE(1.0); - const BaseFloat ZERO(0.0); +// Note: anonymous namespaces are now preferred (by the C++ standard) over +// static variables. +const BaseFloat ONE(1.0); +const BaseFloat ZERO(0.0); } - ConvolutionComputation:: - ConvolutionComputation(int32 num_channels_out, int32 num_channels_in, - int32 filter_height, int32 filter_width, - int32 filter_stride_height, int32 filter_stride_width, - int32 filter_dilation_height, - int32 filter_dilation_width, - int32 num_images, - int32 input_image_height, int32 input_image_width, - int32 zero_padding_height, int32 zero_padding_width) { - CUDNN_SAFE_CALL(cudnnCreateTensorDescriptor(&input_desc_)); - CUDNN_SAFE_CALL(cudnnCreateTensorDescriptor(&output_desc_)); - CUDNN_SAFE_CALL(cudnnCreateFilterDescriptor(¶ms_desc_)); - CUDNN_SAFE_CALL(cudnnCreateTensorDescriptor(&bias_desc_)); - CUDNN_SAFE_CALL(cudnnCreateConvolutionDescriptor(&conv_desc_)); - CUDNN_SAFE_CALL(cudnnCreateActivationDescriptor(&activation_desc_)); +ConvolutionComputation:: +ConvolutionComputation(int32 num_channels_out, int32 num_channels_in, + int32 filter_height, int32 filter_width, + int32 filter_stride_height, int32 filter_stride_width, + int32 filter_dilation_height, + int32 filter_dilation_width, + int32 num_images, + int32 input_image_height, int32 input_image_width, + int32 zero_padding_height, int32 zero_padding_width) { + CUDNN_SAFE_CALL(cudnnCreateTensorDescriptor(&input_desc_)); + CUDNN_SAFE_CALL(cudnnCreateTensorDescriptor(&output_desc_)); + CUDNN_SAFE_CALL(cudnnCreateFilterDescriptor(¶ms_desc_)); + CUDNN_SAFE_CALL(cudnnCreateTensorDescriptor(&bias_desc_)); + CUDNN_SAFE_CALL(cudnnCreateConvolutionDescriptor(&conv_desc_)); + CUDNN_SAFE_CALL(cudnnCreateActivationDescriptor(&activation_desc_)); - CUDNN_SAFE_CALL( - cudnnSetTensor4dDescriptor(input_desc_, CUDNN_TENSOR_NHWC, - CUDNN_DATA_FLOAT, num_images, - num_channels_in, input_image_width, - input_image_height)); - CUDNN_SAFE_CALL( - cudnnSetConvolution2dDescriptor(conv_desc_, - zero_padding_width, zero_padding_height, - filter_stride_width, filter_stride_height, - filter_dilation_width, filter_dilation_height, - CUDNN_CROSS_CORRELATION, // TODO: Double check this! - CUDNN_DATA_FLOAT)); - CUDNN_SAFE_CALL( - cudnnSetFilter4dDescriptor(params_desc_, CUDNN_DATA_FLOAT, - CUDNN_TENSOR_NCHW, num_channels_out, - num_channels_in, filter_width, filter_height)); + CUDNN_SAFE_CALL( + cudnnSetTensor4dDescriptor(input_desc_, CUDNN_TENSOR_NHWC, + CUDNN_DATA_FLOAT, num_images, + num_channels_in, input_image_width, + input_image_height)); + CUDNN_SAFE_CALL( + cudnnSetConvolution2dDescriptor(conv_desc_, + zero_padding_width, zero_padding_height, + filter_stride_width, filter_stride_height, + filter_dilation_width, filter_dilation_height, + CUDNN_CROSS_CORRELATION, // TODO: Double check this! + CUDNN_DATA_FLOAT)); + CUDNN_SAFE_CALL( + cudnnSetFilter4dDescriptor(params_desc_, CUDNN_DATA_FLOAT, + CUDNN_TENSOR_NCHW, num_channels_out, + num_channels_in, filter_width, filter_height)); - // These two member functions depend only on input_desc_, - // conv_desc_, and params_desc_, so they are safe to call now. - int32 out_kaldi_height_cudnn_width = OutputImageHeight(); - int32 out_kaldi_width_cudnn_height = OutputImageWidth(); - CUDNN_SAFE_CALL( - cudnnSetTensor4dDescriptor(output_desc_, CUDNN_TENSOR_NHWC, - CUDNN_DATA_FLOAT, num_images, - num_channels_in, out_kaldi_width_cudnn_height, - out_kaldi_height_cudnn_width)); - const int32 bias_stride[] = {1}; - CUDNN_SAFE_CALL( - cudnnSetTensorNdDescriptor(bias_desc_, CUDNN_DATA_FLOAT, 1, - &num_channels_out, bias_stride)); + // These two member functions depend only on input_desc_, + // conv_desc_, and params_desc_, so they are safe to call now. + int32 out_kaldi_height_cudnn_width = OutputImageHeight(); + int32 out_kaldi_width_cudnn_height = OutputImageWidth(); + CUDNN_SAFE_CALL( + cudnnSetTensor4dDescriptor(output_desc_, CUDNN_TENSOR_NHWC, + CUDNN_DATA_FLOAT, num_images, + num_channels_in, out_kaldi_width_cudnn_height, + out_kaldi_height_cudnn_width)); + const int32 bias_stride[] = {1}; + CUDNN_SAFE_CALL( + cudnnSetTensorNdDescriptor(bias_desc_, CUDNN_DATA_FLOAT, 1, + &num_channels_out, bias_stride)); - const double DONT_CARE = 0; - CUDNN_SAFE_CALL( + const double DONT_CARE = 0; + CUDNN_SAFE_CALL( cudnnSetActivationDescriptor(activation_desc_, CUDNN_ACTIVATION_IDENTITY, CUDNN_PROPAGATE_NAN, DONT_CARE)); - int32 requested_algo_count, returned_algo_count; - CUDNN_SAFE_CALL(cudnnGetConvolutionForwardAlgorithmMaxCount( + int32 requested_algo_count, returned_algo_count; + CUDNN_SAFE_CALL(cudnnGetConvolutionForwardAlgorithmMaxCount( CuDevice::Instantiate().GetCudnnHandle(), &requested_algo_count)); - cudnnConvolutionFwdAlgoPerf_t *forward_results = + cudnnConvolutionFwdAlgoPerf_t *forward_results = new cudnnConvolutionFwdAlgoPerf_t[requested_algo_count]; - CUDNN_SAFE_CALL(cudnnFindConvolutionForwardAlgorithm( + CUDNN_SAFE_CALL(cudnnFindConvolutionForwardAlgorithm( CuDevice::Instantiate().GetCudnnHandle(), input_desc_, params_desc_, @@ -97,17 +99,17 @@ namespace { &returned_algo_count, forward_results)); - KALDI_ASSERT(returned_algo_count > 0 && - "No algorithms were returned by CUDNN."); - const cudnnConvolutionFwdAlgoPerf_t& best_forward = forward_results[0]; - fwd_algo_ = best_forward.algo; - delete [] forward_results; + KALDI_ASSERT(returned_algo_count > 0 && + "No algorithms were returned by CUDNN."); + const cudnnConvolutionFwdAlgoPerf_t& best_forward = forward_results[0]; + fwd_algo_ = best_forward.algo; + delete [] forward_results; - CUDNN_SAFE_CALL(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount( + CUDNN_SAFE_CALL(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount( CuDevice::Instantiate().GetCudnnHandle(), &requested_algo_count)); - cudnnConvolutionBwdFilterAlgoPerf_t *backward_filter_results = + cudnnConvolutionBwdFilterAlgoPerf_t *backward_filter_results = new cudnnConvolutionBwdFilterAlgoPerf_t[requested_algo_count]; - CUDNN_SAFE_CALL(cudnnFindConvolutionBackwardFilterAlgorithm( + CUDNN_SAFE_CALL(cudnnFindConvolutionBackwardFilterAlgorithm( CuDevice::Instantiate().GetCudnnHandle(), input_desc_, output_desc_, @@ -116,18 +118,18 @@ namespace { requested_algo_count, &returned_algo_count, backward_filter_results)); - KALDI_ASSERT(returned_algo_count > 0 && - "No algorithms were returned by CUDNN."); - const cudnnConvolutionBwdFilterAlgoPerf_t& best_backward_filter = + KALDI_ASSERT(returned_algo_count > 0 && + "No algorithms were returned by CUDNN."); + const cudnnConvolutionBwdFilterAlgoPerf_t& best_backward_filter = backward_filter_results[0]; - bwd_filter_algo_ = best_backward_filter.algo; - delete [] backward_filter_results; + bwd_filter_algo_ = best_backward_filter.algo; + delete [] backward_filter_results; - CUDNN_SAFE_CALL(cudnnGetConvolutionBackwardDataAlgorithmMaxCount( + CUDNN_SAFE_CALL(cudnnGetConvolutionBackwardDataAlgorithmMaxCount( CuDevice::Instantiate().GetCudnnHandle(), &requested_algo_count)); - cudnnConvolutionBwdDataAlgoPerf_t *backward_data_results = + cudnnConvolutionBwdDataAlgoPerf_t *backward_data_results = new cudnnConvolutionBwdDataAlgoPerf_t[requested_algo_count]; - CUDNN_SAFE_CALL(cudnnFindConvolutionBackwardDataAlgorithm( + CUDNN_SAFE_CALL(cudnnFindConvolutionBackwardDataAlgorithm( CuDevice::Instantiate().GetCudnnHandle(), params_desc_, output_desc_, @@ -136,50 +138,50 @@ namespace { requested_algo_count, &returned_algo_count, backward_data_results)); - KALDI_ASSERT(returned_algo_count > 0 && - "No algorithms were returned by CUDNN."); - const cudnnConvolutionBwdDataAlgoPerf_t& best_backward_data = + KALDI_ASSERT(returned_algo_count > 0 && + "No algorithms were returned by CUDNN."); + const cudnnConvolutionBwdDataAlgoPerf_t& best_backward_data = backward_data_results[0]; - bwd_data_algo_ = best_backward_data.algo; - delete [] backward_data_results; - } + bwd_data_algo_ = best_backward_data.algo; + delete [] backward_data_results; +} - ConvolutionComputation::~ConvolutionComputation() { - CUDNN_SAFE_CALL(cudnnDestroyTensorDescriptor(input_desc_)); - CUDNN_SAFE_CALL(cudnnDestroyTensorDescriptor(output_desc_)); - CUDNN_SAFE_CALL(cudnnDestroyFilterDescriptor(params_desc_)); - CUDNN_SAFE_CALL(cudnnDestroyTensorDescriptor(bias_desc_)); - CUDNN_SAFE_CALL(cudnnDestroyConvolutionDescriptor(conv_desc_)); - CUDNN_SAFE_CALL(cudnnDestroyActivationDescriptor(activation_desc_)); - } +ConvolutionComputation::~ConvolutionComputation() { + CUDNN_SAFE_CALL(cudnnDestroyTensorDescriptor(input_desc_)); + CUDNN_SAFE_CALL(cudnnDestroyTensorDescriptor(output_desc_)); + CUDNN_SAFE_CALL(cudnnDestroyFilterDescriptor(params_desc_)); + CUDNN_SAFE_CALL(cudnnDestroyTensorDescriptor(bias_desc_)); + CUDNN_SAFE_CALL(cudnnDestroyConvolutionDescriptor(conv_desc_)); + CUDNN_SAFE_CALL(cudnnDestroyActivationDescriptor(activation_desc_)); +} - int32 ConvolutionComputation::OutputImageHeight() const { - int32 unused; - int32 kaldi_height_cudnn_width; - CUDNN_SAFE_CALL( +int32 ConvolutionComputation::OutputImageHeight() const { + int32 unused; + int32 kaldi_height_cudnn_width; + CUDNN_SAFE_CALL( cudnnGetConvolution2dForwardOutputDim(conv_desc_, input_desc_, params_desc_, &unused, &unused, &unused, &kaldi_height_cudnn_width)); - return kaldi_height_cudnn_width; - } + return kaldi_height_cudnn_width; +} - int32 ConvolutionComputation::OutputImageWidth() const { - int32 unused; - int32 kaldi_width_cudnn_height; - CUDNN_SAFE_CALL( +int32 ConvolutionComputation::OutputImageWidth() const { + int32 unused; + int32 kaldi_width_cudnn_height; + CUDNN_SAFE_CALL( cudnnGetConvolution2dForwardOutputDim(conv_desc_, input_desc_, params_desc_, &unused, &unused, &kaldi_width_cudnn_height, &unused)); - return kaldi_width_cudnn_height; - } + return kaldi_width_cudnn_height; +} - size_t ConvolutionComputation::TempSpaceRequiredForward() const { - size_t workspace_size_bytes; - CUDNN_SAFE_CALL(cudnnGetConvolutionForwardWorkspaceSize( +size_t ConvolutionComputation::TempSpaceRequiredForward() const { + size_t workspace_size_bytes; + CUDNN_SAFE_CALL(cudnnGetConvolutionForwardWorkspaceSize( CuDevice::Instantiate().GetCudnnHandle(), input_desc_, params_desc_, @@ -187,12 +189,12 @@ namespace { output_desc_, fwd_algo_, &workspace_size_bytes)); - return workspace_size_bytes; - } + return workspace_size_bytes; +} - size_t ConvolutionComputation::TempSpaceRequiredBackwardData() const { - size_t workspace_size_bytes; - CUDNN_SAFE_CALL(cudnnGetConvolutionBackwardDataWorkspaceSize( +size_t ConvolutionComputation::TempSpaceRequiredBackwardData() const { + size_t workspace_size_bytes; + CUDNN_SAFE_CALL(cudnnGetConvolutionBackwardDataWorkspaceSize( CuDevice::Instantiate().GetCudnnHandle(), params_desc_, output_desc_, @@ -200,13 +202,13 @@ namespace { input_desc_, bwd_data_algo_, &workspace_size_bytes)); - return workspace_size_bytes; - } + return workspace_size_bytes; +} - size_t ConvolutionComputation::TempSpaceRequiredBackwardFilter() const { - size_t workspace_size_bytes; - CUDNN_SAFE_CALL(cudnnGetConvolutionBackwardFilterWorkspaceSize( +size_t ConvolutionComputation::TempSpaceRequiredBackwardFilter() const { + size_t workspace_size_bytes; + CUDNN_SAFE_CALL(cudnnGetConvolutionBackwardFilterWorkspaceSize( CuDevice::Instantiate().GetCudnnHandle(), input_desc_, output_desc_, @@ -214,18 +216,18 @@ namespace { params_desc_, bwd_filter_algo_, &workspace_size_bytes)); - return workspace_size_bytes; - } + return workspace_size_bytes; +} - void ConvolutionComputation:: - ConvolveForward(const CuMatrixBase &input, - const CuMatrixBase ¶ms, - const CuVectorBase &bias, - CuVectorBase *temp_space, - CuMatrixBase *output) const { - CUDNN_SAFE_CALL(cudnnConvolutionBiasActivationForward( +void ConvolutionComputation:: +ConvolveForward(const CuMatrixBase &input, + const CuMatrixBase ¶ms, + const CuVectorBase &bias, + CuVectorBase *temp_space, + CuMatrixBase *output) const { + CUDNN_SAFE_CALL(cudnnConvolutionBiasActivationForward( CuDevice::Instantiate().GetCudnnHandle(), &ONE, input_desc_, @@ -244,14 +246,14 @@ namespace { activation_desc_, output_desc_, output->Data())); - } +} - void ConvolutionComputation:: - ConvolveBackwardData(const CuMatrixBase ¶ms, - const CuMatrixBase &output_deriv, - CuVectorBase *temp, - CuMatrixBase *input_deriv) const { - CUDNN_SAFE_CALL(cudnnConvolutionBackwardData( +void ConvolutionComputation:: +ConvolveBackwardData(const CuMatrixBase ¶ms, + const CuMatrixBase &output_deriv, + CuVectorBase *temp, + CuMatrixBase *input_deriv) const { + CUDNN_SAFE_CALL(cudnnConvolutionBackwardData( CuDevice::Instantiate().GetCudnnHandle(), &ONE, params_desc_, @@ -265,15 +267,15 @@ namespace { &ZERO, input_desc_, input_deriv->Data())); - } +} - void ConvolutionComputation:: - ConvolveBackwardParams(const CuMatrixBase &output_deriv, - const CuMatrixBase &input, - BaseFloat alpha, - CuVectorBase *temp, - CuMatrixBase *params_deriv) const { - CUDNN_SAFE_CALL(cudnnConvolutionBackwardFilter( +void ConvolutionComputation:: +ConvolveBackwardParams(const CuMatrixBase &output_deriv, + const CuMatrixBase &input, + BaseFloat alpha, + CuVectorBase *temp, + CuMatrixBase *params_deriv) const { + CUDNN_SAFE_CALL(cudnnConvolutionBackwardFilter( CuDevice::Instantiate().GetCudnnHandle(), &alpha, input_desc_, @@ -287,13 +289,13 @@ namespace { &ONE, params_desc_, params_deriv->Data())); - } +} - void ConvolutionComputation:: - ConvolveBackwardBias(const CuMatrixBase &output_deriv, - BaseFloat alpha, - CuVectorBase *bias_deriv) const { - CUDNN_SAFE_CALL(cudnnConvolutionBackwardBias( +void ConvolutionComputation:: +ConvolveBackwardBias(const CuMatrixBase &output_deriv, + BaseFloat alpha, + CuVectorBase *bias_deriv) const { + CUDNN_SAFE_CALL(cudnnConvolutionBackwardBias( CuDevice::Instantiate().GetCudnnHandle(), &alpha, output_desc_, @@ -301,7 +303,7 @@ namespace { &ONE, bias_desc_, bias_deriv->Data())); - } +} } // namespace cudnn } // namespace nnet3 diff --git a/src/nnet3/convolution-cudnn.h b/src/nnet3/convolution-cudnn.h index bfd5403b5d9..8d3e6cdea21 100644 --- a/src/nnet3/convolution-cudnn.h +++ b/src/nnet3/convolution-cudnn.h @@ -31,64 +31,68 @@ namespace kaldi { namespace nnet3 { namespace cudnn { +/** + Represents structural information about a convolution computation, with + filters, padding, striding, inputs and outputs of a specified size. The + same interface is usable on both GPU and CPU. You create this object only + after you know the number of images and input and output sizes, and it will + be stored as part of a NnetComputation (i.e. a compiled computation) and + re-used between different minibatches. This object is lightweight; it + doesn't contain data, only a few integers and descriptors. + + In the following docstrings, consider: + N to be equivalent to num_images + C to be equivalent to num_channels_in + K to be equivalent to num_channels_out + H to be equivalent to input_image_height (for images) or + filter_height (for filter parameters). + W to be equivalent to input_image_width (for images) or + filter_width (for filter parameters). + + + @param [in] num_channels_out Number of output channels, e.g. 64. + @param [in] num_channels_in Number of input channels, e.g. 32. + @param [in] filter_height Height of filter patch, e.g. 3 (for 3x3 kernel). Corresponds + to the 'frequency' dimension in normal speech applications, or + height in OCR applications. + @param [in] filter_width Width of filter patch, e.g. 3 (for 3x3 kernel). Corresponds + to the 'time' dimension in normal speech applications. + @param [in] filter_stride_height Filter stride in the height ('frequency') dimension. + Will normally be 1 in speech and OCR applications. + @param [in] filter_stride_width Filter stride in the width ('time') dimension. + Will usually be 1 in most layers, but may be 2 or 3 if + we are doing subsampling on this layer (e.g. in + reduced-frame-rate models like chain models). + @param [in] filter_dilation_height Filter dilation in the height ('frequency') + dimension. Equals the stride, in the input image, of + individual elements of the filter patch. Will + normally be 1. + @param [in] filter_dilation_width Filter dilation in the width ('time') + dimension. Will normally be 1, but could + be more than one if, for instance, you have components + with time-stride > 1 which for some reason are required + to be evaluated on every frame. + @param [in] num_images The number of images we are processing, generally + equal to the minibatch size. + @param [in] input_image_height The height of the input images. Corresponds to + the number of frequency bins, in speech applications. + @param [in] input_image_width The width of the input images. Corresponds to + the number of time frames on the input, in speech + applications. + @param [in] zero_padding_height The number of pixels that we zero-pad with on + the bottom, and on the top, of the image (the + frequency dimension, in speech applications). Would + be 1, for instance, if you are using a 3x3 kernel + and don't want to lose frequency bins. + @param [in] zero_padding_width The number of frames that we zero-pad with on + the left, and on the right, of the image (time + dimension). Likely to be 0 in many speech applications, + since we normally deal with edge effects by padding + with repeats of the first and last frame; but + padding is supported by the component. +*/ class ConvolutionComputation final { public: - // Represents structural information about a convolution computation, - // with filters, padding, striding, inputs and outputs of a specified size. The same interface - // is usable on both GPU and CPU. You create this object only after you know the - // number of images and input and output sizes, and it will be stored as part of - // a NnetComputation (i.e. a compiled computation) and re-used between different - // minibatches. This object is lightweight. - // - // In the following docstrings, consider: - // N to be equivalent to num_images - // C to be equivalent to num_channels_in - // K to be equivalent to num_channels_out - // H to be equivalent to input_image_height, or filter_height, - // depending on context - // W to be equivalent to input_image_width, or filter_width, - // depending on context - // - // @param [in] num_channels_out Number of output channels, e.g. 64. - // @param [in] num_channels_in Number of input channels, e.g. 32. - // @param [in] filter_height Height of filter patch, e.g. 3 (for 3x3 kernel). Corresponds - // to the 'frequency' dimension in normal speech applications, or - // height in OCR applications. - // @param [in] filter_width Width of filter patch, e.g. 3 (for 3x3 kernel). Corresponds - // to the 'time' dimension in normal speech applications. - // @param [in] filter_stride_height Filter stride in the height ('frequency') dimension. - // Will normally be 1 in speech and OCR applications. - // @param [in] filter_stride_width Filter stride in the width ('time') dimension. - // Will usually be 1 in most layers, but may be 2 or 3 if - // we are doing subsampling on this layer (e.g. in - // reduced-frame-rate models like chain models). - // @param [in] filter_dilation_height Filter dilation in the height ('frequency') - // dimension. Equals the stride, in the input image, of - // individual elements of the filter patch. Will - // normally be 1. - // @param [in] filter_dilation_width Filter dilation in the width ('time') - // dimension. Will normally be 1, but could - // be more than one if, for instance, you have components - // with time-stride > 1 which for some reason are required - // to be evaluated on every frame. - // @param [in] num_images The number of images we are processing, generally - // equal to the minibatch size. - // @param [in] input_image_height The height of the input images. Corresponds to - // the number of frequency bins, in speech applications. - // @param [in] input_image_width The width of the input images. Corresponds to - // the number of time frames on the input, in speech - // applications. - // @param [in] zero_padding_height The number of pixels that we zero-pad with on - // the bottom, and on the top, of the image (the - // frequency dimension, in speech applications). Would - // be 1, for instance, if you are using a 3x3 kernel - // and don't want to lose frequency bins. - // @param [in] zero_padding_width The number of frames that we zero-pad with on - // the left, and on the right, of the image (time - // dimension). Likely to be 0 in many speech applications, - // since we normally deal with edge effects by padding - // with repeats of the first and last frame; but - // padding is supported by the component. ConvolutionComputation(int32 num_channels_out, int32 num_channels_in, int32 filter_height, int32 filter_width, int32 filter_stride_height, int32 filter_stride_width, @@ -110,10 +114,19 @@ class ConvolutionComputation final { size_t TempSpaceRequiredBackwardFilter() const; /** + * For an explanation of the notation below (e.g. NWHC), see the + * explanation for those variable names in the documentation for this + * class above. Variables that come first have the higher stride. + * + * Caution: for convenience, given the way nnet3 works, we flip the notion of + * height and width that CUDNN uses, so our height is CUDNN's width, and vice + * versa. This is not visible to the user; we mention it just in case + * those familiar with CUDNN get surprised at the order + * * @param [in] input NWHC fully-packed tensor, with N == NumRows() * @param [in] params KCWH fully-packed tensor, with K == NumRows() * @param [in] bias vector of length K - * @param [in/out] temp_space Pointer to pre-allocated memory of size at least + * @param [in/out] temp_space Pointer to pre-allocated memory of size at least * this->TempSpaceRequiredForward() bytes * @param [out] output Pre-allocated NWHK fully-packed tensor, with N == NumRows() */ @@ -126,7 +139,7 @@ class ConvolutionComputation final { /** * @param [in] params KCWH fully-packed tensor, with K == NumRows() * @param [in] output_deriv NWHK fully-packed tensor, with N == NumRows() - * @param [in/out] temp_space Pointer to pre-allocated memory of size at least + * @param [in/out] temp_space Pointer to pre-allocated memory of size at least * this->TempSpaceRequiredBackwardData() bytes * @param [out] input_deriv Pre-allocated NWHC fully-packed tensor, with N == NumRows() */ @@ -138,10 +151,10 @@ class ConvolutionComputation final { /** * @param [in] output_deriv NWHK fully-packed tensor, with N == NumRows() * @param [in] input NWHC fully-packed tensor, with N == NumRows() - * @param [in] alpha + * @param [in] alpha * params_deriv := alpha * gradient_computed + params_deriv * @param [in] params KCWH fully-packed tensor, with K == NumRows() - * @param [in/out] temp_space Pointer to pre-allocated memory of size at least + * @param [in/out] temp_space Pointer to pre-allocated memory of size at least * this->TempSpaceRequiredBackwardFilter() bytes * @param [out] params_deriv Pre-allocated KCWH fully-packed tensor, with K == NumRows() */ @@ -153,7 +166,7 @@ class ConvolutionComputation final { /** * @param [in] output_deriv NWHK fully-packed tensor, with N == NumRows() - * @param [in] alpha + * @param [in] alpha * bias_deriv := alpha * gradient_computed + bias_deriv * @param [out] bias_deriv Pre-allocated vector of length K */ From 964749410918bbc6956ed69dc4d2387bc295da71 Mon Sep 17 00:00:00 2001 From: Daniel Galvez Date: Sat, 13 Oct 2018 01:55:04 -0400 Subject: [PATCH 11/22] Automatically download CUDNN as part of configure. Remove support for CUDA versions that don't support CUDNN v7. Remove 32-bit CUDA mode, since no recent version of CUDA supports it. I may not have removed all instances of it. --- .gitignore | 1 + src/configure | 51 +++++++++++++++++++++++++++++-------- src/makefiles/cuda_32bit.mk | 18 ------------- 3 files changed, 41 insertions(+), 29 deletions(-) delete mode 100644 src/makefiles/cuda_32bit.mk diff --git a/.gitignore b/.gitignore index cdcd13ec8b5..9c383f1a892 100644 --- a/.gitignore +++ b/.gitignore @@ -144,3 +144,4 @@ GSYMS /tools/mmseg-1.3.0.tar.gz /tools/mmseg-1.3.0/ /kaldiwin_vs* +/tools/cudnn/ diff --git a/src/configure b/src/configure index 29d80b76c67..1adaf388c92 100755 --- a/src/configure +++ b/src/configure @@ -423,9 +423,6 @@ function configure_cuda { fi case $CUDA_VERSION in - 5_5) CUDA_ARCH="-gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35" ;; - 6_*) CUDA_ARCH="-gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_50,code=sm_50" ;; - 7_*) CUDA_ARCH="-gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_53,code=sm_53" ;; 8_*) CUDA_ARCH="-gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62" ;; 9_*) CUDA_ARCH="-gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62 -gencode arch=compute_70,code=sm_70" ;; 10_*) CUDA_ARCH="-gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_72,code=sm_72 -gencode arch=compute_75,code=sm_75" ;; @@ -439,9 +436,10 @@ function configure_cuda { echo CUDA = true >> kaldi.mk echo CUDATKDIR = $CUDATKDIR >> kaldi.mk echo "CUDA_ARCH = $CUDA_ARCH" >> kaldi.mk - echo CUDNNDIR = $CUDNNDIR >> kaldi.mk echo >> kaldi.mk + configure_cudnn + # 64bit/32bit? We do not support cross compilation with CUDA so, use direct calls to uname -m here if [ "`uname -m`" == "x86_64" ]; then if [ "`uname`" == "Darwin" ]; then @@ -454,7 +452,7 @@ function configure_cuda { elif [ "`uname -m`" == "ppc64le" ]; then cat makefiles/cuda_64bit.mk >> kaldi.mk else - cat makefiles/cuda_32bit.mk >> kaldi.mk + echo "Unexpected architecture `uname -m`"; exit 1 fi else @@ -462,15 +460,46 @@ function configure_cuda { echo "and cuda toolkit, try using --cudatk-dir=... option. Note: this is" echo "only relevant for neural net experiments" fi +} - if [ ! -z $CUDNNDIR ]; then - if [ ! -f $CUDNNDIR/lib64/libcudnn.so ] | - [ ! -f $CUDNNDIR/include/cudnn.h ]; then - echo "CUDNNDIR(=$CUDNNDIR) invalid!" - fi +function configure_cudnn { + if [ -z $CUDNNDIR ]; then + download_appropriate_cudnn + fi + + echo CUDNNDIR = $CUDNNDIR >> kaldi.mk + echo >> kaldi.mk + + if [ ! -f $CUDNNDIR/lib64/libcudnn.so ] | + [ ! -f $CUDNNDIR/include/cudnn.h ]; then + echo "CUDNNDIR(=$CUDNNDIR) invalid!" + fi +} - +function download_appropriate_cudnn { + CUDNNDIR=`rel2abs ../tools/cudnn/cuda/` + + if [ -f $CUDNNDIR/include/cudnn.h ]; then + echo -n "CUDNN has been downloaded already. If you'd like to redownload it " + echo -n "(e.g., because you changed CUDA version), please delete $CUDNNDIR " + echo "and rerun configure" + return fi + + local cudnn_url + case $CUDA_VERSION in + 8_0) cudnn_url="http://developer.download.nvidia.com/compute/redist/cudnn/v7.1.2/cudnn-8.0-linux-x64-v7.1.tgz" ;; + 9_0) cudnn_url="http://developer.download.nvidia.com/compute/redist/cudnn/v7.3.1/cudnn-9.0-linux-x64-v7.3.1.20.tgz" ;; + 9_1) cudnn_url="http://developer.download.nvidia.com/compute/redist/cudnn/v7.1.2/cudnn-9.1-linux-x64-v7.1.tgz" ;; + 9_2) cudnn_url="http://developer.download.nvidia.com/compute/redist/cudnn/v7.2.1/cudnn-9.2-linux-x64-v7.2.1.38.tgz" ;; + 10_0) cudnn_url="http://developer.download.nvidia.com/compute/redist/cudnn/v7.3.1/cudnn-10.0-linux-x64-v7.3.1.20.tgz" ;; + *) echo "No known CUDNN download for provided CUDA_VERSION. Try checking here to see if your CUDA version supports a reasonably new version of CUDNN: https://gitlab.com/nvidia/cuda/tree/centos7"; exit 1 ;; + esac + + local extract_dir=$CUDNNDIR/.. + mkdir -p $extract_dir + wget -T 10 -t 3 $cudnn_url -O $extract_dir/cudnn.tgz + tar --no-same-owner -xzf $extract_dir/cudnn.tgz -C $extract_dir } function linux_configure_speex { diff --git a/src/makefiles/cuda_32bit.mk b/src/makefiles/cuda_32bit.mk deleted file mode 100644 index c8908ef45af..00000000000 --- a/src/makefiles/cuda_32bit.mk +++ /dev/null @@ -1,18 +0,0 @@ -ifndef DOUBLE_PRECISION -$(error DOUBLE_PRECISION not defined.) -endif -ifndef CUDATKDIR -$(error CUDATKDIR not defined.) -endif -ifndef CUDNNDIR -$(error CUDNNDIR not defined.) -endif - - -CUDA_INCLUDE= -I$(CUDATKDIR)/include -CUDA_FLAGS = -g -Xcompiler -fPIC --verbose --machine 32 -DHAVE_CUDA=1 \ - -ccbin $(CXX) -DKALDI_DOUBLEPRECISION=$(DOUBLE_PRECISION) \ - -DCUDA_API_PER_THREAD_DEFAULT_STREAM -CXXFLAGS += -DHAVE_CUDA=1 -I$(CUDATKDIR)/include -LDFLAGS += -L$(CUDATKDIR)/lib -Wl,-rpath=$(CUDATKDIR)/lib -LDLIBS += -lcudnn -lcublas -lcusparse -lcudart -lcurand #LDLIBS : The libs are loaded later than static libs in implicit rule From 53d62afa7adfeca2fd156174d84d63b4a0b59f76 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 15 Oct 2018 16:04:41 -0400 Subject: [PATCH 12/22] [src,build] Get it to compile; some structural changes. --- src/configure | 13 +- src/cudamatrix/cu-common.h | 1 - src/cudamatrix/cu-cudnn-helper.h | 42 --- src/cudamatrix/cu-device.h | 2 +- src/matrix/kaldi-matrix.cc | 9 +- src/matrix/kaldi-matrix.h | 7 +- src/nnet3/Makefile | 5 +- src/nnet3/convolution-cudnn.cc | 491 +++++++++++++++++++++++-------- src/nnet3/convolution-cudnn.h | 167 ++++++++--- 9 files changed, 517 insertions(+), 220 deletions(-) delete mode 100644 src/cudamatrix/cu-cudnn-helper.h diff --git a/src/configure b/src/configure index 1adaf388c92..b09488ac75b 100755 --- a/src/configure +++ b/src/configure @@ -477,7 +477,9 @@ function configure_cudnn { } function download_appropriate_cudnn { - CUDNNDIR=`rel2abs ../tools/cudnn/cuda/` + local tools=`rel2abs ../tools` + install_dir=$tools/cudnn + CUDNNDIR=$tools/cudnn/cuda if [ -f $CUDNNDIR/include/cudnn.h ]; then echo -n "CUDNN has been downloaded already. If you'd like to redownload it " @@ -496,10 +498,9 @@ function download_appropriate_cudnn { *) echo "No known CUDNN download for provided CUDA_VERSION. Try checking here to see if your CUDA version supports a reasonably new version of CUDNN: https://gitlab.com/nvidia/cuda/tree/centos7"; exit 1 ;; esac - local extract_dir=$CUDNNDIR/.. - mkdir -p $extract_dir - wget -T 10 -t 3 $cudnn_url -O $extract_dir/cudnn.tgz - tar --no-same-owner -xzf $extract_dir/cudnn.tgz -C $extract_dir + mkdir -p $install_dir + wget -T 10 -t 3 $cudnn_url -O $install_dir/cudnn.tgz + tar --no-same-owner -xzf $install_dir/cudnn.tgz -C $install_dir } function linux_configure_speex { @@ -1366,7 +1367,7 @@ elif [ "`uname`" == "Linux" ]; then elif [ -f $OPENBLASROOT/include/openblas/cblas.h ] ; then # in REDHAT/CentOS/Ubuntu package installs, the includes are located here OPENBLASINCDIR=$OPENBLASROOT/include/openblas - else + else echo "$0: ***** Using OpenBlas from $OPENBLASROOT but cblas.h is not found. " echo " ****** Assuming openblas is aleady in a default include path, but" echo " ***** if you get compilation messages about not finding files like cblas.h," diff --git a/src/cudamatrix/cu-common.h b/src/cudamatrix/cu-common.h index 621fb07f6b9..42a0a0347d2 100644 --- a/src/cudamatrix/cu-common.h +++ b/src/cudamatrix/cu-common.h @@ -23,7 +23,6 @@ #ifndef KALDI_CUDAMATRIX_CU_COMMON_H_ #define KALDI_CUDAMATRIX_CU_COMMON_H_ #include "cudamatrix/cu-matrixdim.h" // for CU1DBLOCK and CU2DBLOCK -#include "cudamatrix/cu-cudnn-helper.h" #include #include diff --git a/src/cudamatrix/cu-cudnn-helper.h b/src/cudamatrix/cu-cudnn-helper.h deleted file mode 100644 index 9d4ab84aa50..00000000000 --- a/src/cudamatrix/cu-cudnn-helper.h +++ /dev/null @@ -1,42 +0,0 @@ -// cudamatrix/cu-cudnn-helper.h - -// Copyright 2018 Daniel Galvez - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#ifndef KALDI_CUDAMATRIX_CU_CUDNN_HELPER_H_ -#define KALDI_CUDAMATRIX_CU_CUDNN_HELPER_H_ - -#if HAVE_CUDA == 1 -#include -#else -typedef struct cudnnTensorStruct* cudnnTensorDescriptor_t; -typedef struct cudnnConvolutionStruct* cudnnConvolutionDescriptor_t; -typedef struct cudnnPoolingStruct* cudnnPoolingDescriptor_t; -typedef struct cudnnFilterStruct* cudnnFilterDescriptor_t; -typedef struct cudnnLRNStruct* cudnnLRNDescriptor_t; -typedef struct cudnnActivationStruct* cudnnActivationDescriptor_t; -typedef struct cudnnSpatialTransformerStruct* cudnnSpatialTransformerDescriptor_t; -typedef struct cudnnOpTensorStruct* cudnnOpTensorDescriptor_t; -typedef struct cudnnReduceTensorStruct* cudnnReduceTensorDescriptor_t; -typedef struct cudnnCTCLossStruct* cudnnCTCLossDescriptor_t; - -typedef enum {} cudnnConvolutionBwdDataAlgo_t; -typedef enum {} cudnnConvolutionBwdFilterAlgo_t; -typedef enum {} cudnnConvolutionFwdAlgo_t; -#endif // HAVE_CUDA == 1 - -#endif // KALDI_CUDAMATRIX_CU_CUDNN_HELPER_H_ diff --git a/src/cudamatrix/cu-device.h b/src/cudamatrix/cu-device.h index 2ff641d2d1c..95c447f3a7b 100644 --- a/src/cudamatrix/cu-device.h +++ b/src/cudamatrix/cu-device.h @@ -31,11 +31,11 @@ #include #include #include +#include #include #include "base/kaldi-common.h" #include "base/timer.h" #include "cudamatrix/cu-allocator.h" -#include "cudamatrix/cu-cudnn-helper.h" namespace kaldi { diff --git a/src/matrix/kaldi-matrix.cc b/src/matrix/kaldi-matrix.cc index fcfe0616b64..ebdaca77e4f 100644 --- a/src/matrix/kaldi-matrix.cc +++ b/src/matrix/kaldi-matrix.cc @@ -28,7 +28,7 @@ #include "matrix/compressed-matrix.h" #include "matrix/sparse-matrix.h" -static_assert(int(kaldi::kNoTrans) == int(CblasNoTrans) && int(kaldi::kTrans) == int(CblasTrans), +static_assert(int(kaldi::kNoTrans) == int(CblasNoTrans) && int(kaldi::kTrans) == int(CblasTrans), "kaldi::kNoTrans and kaldi::kTrans must be equal to the appropriate CBLAS library constants!"); namespace kaldi { @@ -538,7 +538,7 @@ void MatrixBase::AddMatSmat(Real alpha, const MatrixBase &A, // pass stride to write a column as matrices are stored in row major order. cblas_Xaxpy(this_num_rows, alpha_B_jk, a_col_k, A.stride_, this_col_j, this->stride_); - //for (MatrixIndexT i = 0; i < this_num_rows; ++i) + //for (MatrixIndexT i = 0; i < this_num_rows; ++i) // this_col_j[i*this->stride_] += alpha_B_jk * a_col_k[i*A.stride_]; } } @@ -1656,11 +1656,12 @@ SubMatrix::SubMatrix(const MatrixBase &M, template -SubMatrix::SubMatrix(Real *data, +SubMatrix::SubMatrix(const Real *data, MatrixIndexT num_rows, MatrixIndexT num_cols, MatrixIndexT stride): - MatrixBase(data, num_cols, num_rows, stride) { // caution: reversed order! + MatrixBase(const_cast(data), + num_cols, num_rows, stride) { // caution: reversed order! if (data == NULL) { KALDI_ASSERT(num_rows * num_cols == 0); this->num_rows_ = 0; diff --git a/src/matrix/kaldi-matrix.h b/src/matrix/kaldi-matrix.h index a973824128c..a67f75260bf 100644 --- a/src/matrix/kaldi-matrix.h +++ b/src/matrix/kaldi-matrix.h @@ -952,9 +952,10 @@ class SubMatrix : public MatrixBase { const MatrixIndexT co, // column offset, 0 < co < NumCols() const MatrixIndexT c); // number of columns, c > 0 - // This initializer is mostly intended for use in CuMatrix and related - // classes. Be careful! - SubMatrix(Real *data, + // This initializer does not take ownership of the pointer, and to use it you + // need to have some understanding of how this library works. Caution: + // it can be used to get around const limitations, so be careful. + SubMatrix(const Real *data, MatrixIndexT num_rows, MatrixIndexT num_cols, MatrixIndexT stride); diff --git a/src/nnet3/Makefile b/src/nnet3/Makefile index 2c3f41d8aa1..b39e3d64f31 100644 --- a/src/nnet3/Makefile +++ b/src/nnet3/Makefile @@ -31,7 +31,8 @@ OBJFILES = nnet-common.o nnet-compile.o nnet-component-itf.o \ nnet-compile-looped.o decodable-simple-looped.o \ decodable-online-looped.o convolution.o \ nnet-convolutional-component.o attention.o \ - nnet-attention-component.o nnet-tdnn-component.o nnet-batch-compute.o + nnet-attention-component.o nnet-tdnn-component.o \ + nnet-batch-compute.o convolution-cudnn.o ifeq ($(CUDA), true) @@ -45,6 +46,6 @@ ADDLIBS = ../chain/kaldi-chain.a ../cudamatrix/kaldi-cudamatrix.a \ ../fstext/kaldi-fstext.a ../hmm/kaldi-hmm.a \ ../transform/kaldi-transform.a ../gmm/kaldi-gmm.a \ ../tree/kaldi-tree.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \ - ../base/kaldi-base.a + ../base/kaldi-base.a include ../makefiles/default_rules.mk diff --git a/src/nnet3/convolution-cudnn.cc b/src/nnet3/convolution-cudnn.cc index f39db408129..4c5a69d3532 100644 --- a/src/nnet3/convolution-cudnn.cc +++ b/src/nnet3/convolution-cudnn.cc @@ -1,6 +1,7 @@ // nnet3/convolution-cudnn.cc // Copyright 2018 Daniel Galvez +// 2018 Johns Hopkins University (author: Daniel Povey) // See ../../COPYING for clarification regarding multiple authors // @@ -21,7 +22,7 @@ namespace kaldi { namespace nnet3 { -namespace cudnn { +namespace cudnn_convolution { namespace { @@ -34,12 +35,37 @@ const BaseFloat ZERO(0.0); ConvolutionComputation:: ConvolutionComputation(int32 num_channels_out, int32 num_channels_in, int32 filter_height, int32 filter_width, - int32 filter_stride_height, int32 filter_stride_width, + int32 filter_stride_vertical, int32 filter_stride_horizontal, int32 filter_dilation_height, int32 filter_dilation_width, int32 num_images, int32 input_image_height, int32 input_image_width, - int32 zero_padding_height, int32 zero_padding_width) { + int32 zero_padding_height, int32 zero_padding_width): + num_channels_out_(num_channels_out), + num_channels_in_(num_channels_in), + filter_height_(filter_height), + filter_width_(filter_width), + filter_stride_vertical_(filter_stride_vertical), + filter_stride_horizontal_(filter_stride_horizontal), + filter_dilation_height_(filter_dilation_height), + filter_dilation_width_(filter_dilation_width), + num_images_(num_images), + input_image_height_(input_image_height), + input_image_width_(input_image_width), + zero_padding_height_(zero_padding_height), + zero_padding_width_(zero_padding_width) { +#if HAVE_CUDA == 1 + if (CuDevice::Instantiate().Enabled()) { + InitCudnn(); + } +#endif + // The following is called whether or not we are using CUDA. + ComputeOutputImageHeight(); + ComputeOutputImageWidth(); +} + +#if HAVE_CUDA == 1 +void ConvolutionComputation::InitCudnn() { CUDNN_SAFE_CALL(cudnnCreateTensorDescriptor(&input_desc_)); CUDNN_SAFE_CALL(cudnnCreateTensorDescriptor(&output_desc_)); CUDNN_SAFE_CALL(cudnnCreateFilterDescriptor(¶ms_desc_)); @@ -49,20 +75,20 @@ ConvolutionComputation(int32 num_channels_out, int32 num_channels_in, CUDNN_SAFE_CALL( cudnnSetTensor4dDescriptor(input_desc_, CUDNN_TENSOR_NHWC, - CUDNN_DATA_FLOAT, num_images, - num_channels_in, input_image_width, - input_image_height)); + CUDNN_DATA_FLOAT, num_images_, + num_channels_in_, input_image_width_, + input_image_height_)); CUDNN_SAFE_CALL( cudnnSetConvolution2dDescriptor(conv_desc_, - zero_padding_width, zero_padding_height, - filter_stride_width, filter_stride_height, - filter_dilation_width, filter_dilation_height, + zero_padding_width_, zero_padding_height_, + filter_stride_horizontal_, filter_stride_vertical_, + filter_dilation_width_, filter_dilation_height_, CUDNN_CROSS_CORRELATION, // TODO: Double check this! CUDNN_DATA_FLOAT)); CUDNN_SAFE_CALL( cudnnSetFilter4dDescriptor(params_desc_, CUDNN_DATA_FLOAT, - CUDNN_TENSOR_NCHW, num_channels_out, - num_channels_in, filter_width, filter_height)); + CUDNN_TENSOR_NCHW, num_channels_out_, + num_channels_in_, filter_width_, filter_height_)); // These two member functions depend only on input_desc_, // conv_desc_, and params_desc_, so they are safe to call now. @@ -70,13 +96,13 @@ ConvolutionComputation(int32 num_channels_out, int32 num_channels_in, int32 out_kaldi_width_cudnn_height = OutputImageWidth(); CUDNN_SAFE_CALL( cudnnSetTensor4dDescriptor(output_desc_, CUDNN_TENSOR_NHWC, - CUDNN_DATA_FLOAT, num_images, - num_channels_in, out_kaldi_width_cudnn_height, + CUDNN_DATA_FLOAT, num_images_, + num_channels_in_, out_kaldi_width_cudnn_height, out_kaldi_height_cudnn_width)); const int32 bias_stride[] = {1}; CUDNN_SAFE_CALL( cudnnSetTensorNdDescriptor(bias_desc_, CUDNN_DATA_FLOAT, 1, - &num_channels_out, bias_stride)); + &num_channels_out_, bias_stride)); const double DONT_CARE = 0; CUDNN_SAFE_CALL( @@ -144,9 +170,43 @@ ConvolutionComputation(int32 num_channels_out, int32 num_channels_in, backward_data_results[0]; bwd_data_algo_ = best_backward_data.algo; delete [] backward_data_results; + ComputeTempSpaceSizes(); } +#endif -ConvolutionComputation::~ConvolutionComputation() { +#if HAVE_CUDA == 1 +void ConvolutionComputation::ComputeTempSpaceSizes() { + CUDNN_SAFE_CALL(cudnnGetConvolutionForwardWorkspaceSize( + CuDevice::Instantiate().GetCudnnHandle(), + input_desc_, + params_desc_, + conv_desc_, + output_desc_, + fwd_algo_, + &temp_space_required_forward_)); + + CUDNN_SAFE_CALL(cudnnGetConvolutionBackwardDataWorkspaceSize( + CuDevice::Instantiate().GetCudnnHandle(), + params_desc_, + output_desc_, + conv_desc_, + input_desc_, + bwd_data_algo_, + &temp_space_required_backward_data_)); + + CUDNN_SAFE_CALL(cudnnGetConvolutionBackwardFilterWorkspaceSize( + CuDevice::Instantiate().GetCudnnHandle(), + input_desc_, + output_desc_, + conv_desc_, + params_desc_, + bwd_filter_algo_, + &temp_space_required_backward_filter_)); +} +#endif + +#if HAVE_CUDA == 1 +void ConvolutionComputation::DestroyCudnn() { CUDNN_SAFE_CALL(cudnnDestroyTensorDescriptor(input_desc_)); CUDNN_SAFE_CALL(cudnnDestroyTensorDescriptor(output_desc_)); CUDNN_SAFE_CALL(cudnnDestroyFilterDescriptor(params_desc_)); @@ -154,8 +214,36 @@ ConvolutionComputation::~ConvolutionComputation() { CUDNN_SAFE_CALL(cudnnDestroyConvolutionDescriptor(conv_desc_)); CUDNN_SAFE_CALL(cudnnDestroyActivationDescriptor(activation_desc_)); } +#endif -int32 ConvolutionComputation::OutputImageHeight() const { +ConvolutionComputation::~ConvolutionComputation() { +#if HAVE_CUDA == 1 + if (CuDevice::Instantiate().Enabled()) + DestroyCudnn(); +#endif +} + +void ConvolutionComputation::ComputeOutputImageHeight() { + // 'filter_height_reduction' is the amount by which the height of the filter patch + // reduces the effective height of the input image. It's the distance between + // the first and last pixels of the filter patch. E.g. in a 3x3 kernel it + // would be 2. + int32 filter_height_reduction = (filter_height_ - 1) * filter_dilation_height_; + // 'modified_input_height' is the number of times we can shift the filter patch + // (not yet taking account of any filter stride). It's a kind of augmented input-image + // height, after applying zero-padding and subtracting filter_height_reduction. + int32 modified_input_height = + input_image_height_ - filter_height_reduction + (zero_padding_height_ * 2), + s = filter_stride_vertical_; + + // output_image_height_ equals reduced_input_height divided by s (but rounding + // up), which is the number of times we can shift the filter patch by + // filter_stride_vertical_. + output_image_height_ = (modified_input_height + s - 1) / s; + +#if HAVE_CUDA == 1 + // Check that CUDA has the same idea of what the output image height is, as we + // do. This helps check that the CPU and GPU computations are compatible. int32 unused; int32 kaldi_height_cudnn_width; CUDNN_SAFE_CALL( @@ -164,10 +252,31 @@ int32 ConvolutionComputation::OutputImageHeight() const { &unused, &unused, &unused, &kaldi_height_cudnn_width)); - return kaldi_height_cudnn_width; + if (kaldi_height_cudnn_width != output_image_height_) { + KALDI_ERR << "Code error: the height from CUDNN " << kaldi_height_cudnn_width + << " does not match our value " << output_image_height_; + } +#endif } -int32 ConvolutionComputation::OutputImageWidth() const { +void ConvolutionComputation::ComputeOutputImageWidth() { + // 'filter_width_reduction' is the amount by which the width of the filter patch + // reduces the effective width of the input image. It's the distance between + // the first and last pixels of the filter patch. E.g. in a 3x3 kernel it + // would be 2. + int32 filter_width_reduction = (filter_width_ - 1) * filter_dilation_width_; + // 'modified_input_width' is the number of times we can shift the filter patch + // (not yet taking account of any filter stride). It's a kind of augmented input-image + // width, after applying zero-padding and subtracting filter_width_reduction. + int32 modified_input_width = + input_image_width_ - filter_width_reduction + (zero_padding_width_ * 2), + s = filter_stride_horizontal_; + + // output_image_width equals reduced_input_width divided by s (but rounding + // up), which is the number of times we can shift the filter patch by + // filter_stride_horizontal_. + output_image_width_ = (modified_input_width + s - 1) / s; +#if HAVE_CUDA == 1 int32 unused; int32 kaldi_width_cudnn_height; CUDNN_SAFE_CALL( @@ -176,135 +285,287 @@ int32 ConvolutionComputation::OutputImageWidth() const { &unused, &unused, &kaldi_width_cudnn_height, &unused)); - return kaldi_width_cudnn_height; + if (kaldi_width_cudnn_height != output_image_width_) { + KALDI_ERR << "Code error: the height from CUDNN " << kaldi_width_cudnn_height + << " does not match our value " << output_image_width_; + } +#endif } -size_t ConvolutionComputation::TempSpaceRequiredForward() const { - size_t workspace_size_bytes; - CUDNN_SAFE_CALL(cudnnGetConvolutionForwardWorkspaceSize( - CuDevice::Instantiate().GetCudnnHandle(), - input_desc_, - params_desc_, - conv_desc_, - output_desc_, - fwd_algo_, - &workspace_size_bytes)); - return workspace_size_bytes; -} -size_t ConvolutionComputation::TempSpaceRequiredBackwardData() const { - size_t workspace_size_bytes; - CUDNN_SAFE_CALL(cudnnGetConvolutionBackwardDataWorkspaceSize( - CuDevice::Instantiate().GetCudnnHandle(), - params_desc_, - output_desc_, - conv_desc_, - input_desc_, - bwd_data_algo_, - &workspace_size_bytes)); - return workspace_size_bytes; +void ConvolutionComputation::Write(std::ostream &os, bool binary) const { + // TODO: write just num_channels_out_ through zero_padding_width_; + } +void ConvolutionComputation::Read(std::istream &is, bool binary) { + // TODO: read just num_channels_out_ through zero_padding_width_; -size_t ConvolutionComputation::TempSpaceRequiredBackwardFilter() const { - size_t workspace_size_bytes; - CUDNN_SAFE_CALL(cudnnGetConvolutionBackwardFilterWorkspaceSize( - CuDevice::Instantiate().GetCudnnHandle(), - input_desc_, - output_desc_, - conv_desc_, - params_desc_, - bwd_filter_algo_, - &workspace_size_bytes)); - return workspace_size_bytes; +#if HAVE_CUDA == 1 + if (CuDevice::Instantiate().Enabled()) { + InitCudnn(); + } +#endif + // The following are called whether or not we have CUDA. + ComputeOutputImageHeight(); + ComputeOutputImageWidth(); } - void ConvolutionComputation:: ConvolveForward(const CuMatrixBase &input, const CuMatrixBase ¶ms, const CuVectorBase &bias, - CuVectorBase *temp_space, CuMatrixBase *output) const { - CUDNN_SAFE_CALL(cudnnConvolutionBiasActivationForward( - CuDevice::Instantiate().GetCudnnHandle(), - &ONE, - input_desc_, - input.Data(), - params_desc_, - params.Data(), - conv_desc_, - fwd_algo_, - temp_space->Data(), - temp_space->Dim() * sizeof(BaseFloat), - &ZERO, - output_desc_, - output->Data(), - bias_desc_, - bias.Data(), - activation_desc_, - output_desc_, - output->Data())); + // Check some dimensions. + KALDI_ASSERT( + input.NumRows() == num_images_ * input_image_width_ && + input.NumCols() == input_image_height_ * num_channels_in_ && + input.Stride() == input.NumCols() && + params.NumRows() == num_channels_out_ && + params.NumCols() == num_channels_in_ * filter_height_ * filter_width_ && + params.Stride() == params.NumCols() && + bias.Dim() == num_channels_out_ && + output->NumRows() == num_images_ * input_image_height_ && + output->NumCols() == input_image_width_ * num_channels_out_ && + output->Stride() == output->NumCols()); + +#ifdef HAVE_CUDNN + if (CuDevice::Instantiate().Enabled()) { + CuVector temp_space(temp_space_required_forward_ / + sizeof(BaseFloat), kUndefined); + CUDNN_SAFE_CALL(cudnnConvolutionBiasActivationForward( + CuDevice::Instantiate().GetCudnnHandle(), + &ONE, + input_desc_, + input.Data(), + params_desc_, + params.Data(), + conv_desc_, + fwd_algo_, + temp_space.Data(), + temp_space.Dim() * sizeof(BaseFloat), + &ZERO, + output_desc_, + output->Data(), + bias_desc_, + bias.Data(), + activation_desc_, + output_desc_, + output->Data())); + } else +#endif + { + ConvolveForward(input.Mat(), params.Mat(), bias.Vec(), + &(output->Mat())); + } +} + + +void ConvolutionComputation:: +ConvolveForward(const MatrixBase &input, + const MatrixBase ¶ms, + const VectorBase &bias, + MatrixBase *output) const { + // Check some dimensions. + KALDI_ASSERT( + input.NumRows() == num_images_ * input_image_width_ && + input.NumCols() == input_image_height_ * num_channels_in_ && + input.Stride() == input.NumCols() && + params.NumRows() == num_channels_out_ && + params.NumCols() == num_channels_in_ * filter_height_ * filter_width_ && + params.Stride() == params.NumCols() && + bias.Dim() == num_channels_out_ && + output->NumRows() == num_images_ * input_image_height_ && + output->NumCols() == input_image_width_ * num_channels_out_ && + output->Stride() == output->NumCols()); + + + { // Deal with the bias. + SubMatrix output_rearranged( + output->Data(), + num_images_ * input_image_width_ * input_image_height_, + num_channels_out_, num_channels_out_); + output_rearranged.CopyRowsFromVec(bias); + } + + Matrix params_rearranged(filter_width_ * filter_height_, + num_channels_out_ * num_channels_in_, + kUndefined, kStrideEqualNumCols); + ConvertParams(params, ¶ms_rearranged); + + // We're using variable names w (as in width) for horizontal positions and h + // (as in height) for vertical positions. This is perhaps not ideal. + for (int32 output_w = 0; output_w < output_image_width_; output_w++) { + for (int32 output_h = 0; output_h < output_image_height_; output_h++) { + for (int32 filter_h = 0; filter_h < filter_height_; filter_h++) { + int32 filter_h_flipped = filter_height_ - 1 - filter_h; + int32 input_h = output_h * filter_stride_vertical_ + - zero_padding_height_ + + filter_h * filter_dilation_height_; + if (input_h < 0 || input_h >= input_image_height_) + continue; + for (int32 filter_w = 0; filter_w < filter_width_; filter_w++) { + int32 filter_w_flipped = filter_width_ - 1 - filter_w; + int32 input_w = output_w * filter_stride_horizontal_ + - zero_padding_width_ + + filter_w * filter_dilation_width_; + + if (input_w < 0 || input_w >= input_image_width_) + continue; + + const BaseFloat *params_data = params_rearranged.RowData( + filter_w_flipped * filter_height_ + filter_h_flipped); + SubMatrix this_params(params_data, + num_channels_out_, + num_channels_in_, num_channels_in_); + const BaseFloat *input_data = input.Data() + + input_w * input_image_height_ * num_channels_in_ + + input_h * num_channels_in_; + SubMatrix this_input_pixel(input_data, + num_images_, + num_channels_in_, + num_channels_in_); + SubMatrix this_output_pixel(input_data, + num_images_, + num_channels_in_, + num_channels_in_); + this_output_pixel.AddMatMat(1.0, this_input_pixel, kNoTrans, + this_params, kTrans, 1.0); + } + } + } + } } void ConvolutionComputation:: ConvolveBackwardData(const CuMatrixBase ¶ms, const CuMatrixBase &output_deriv, - CuVectorBase *temp, CuMatrixBase *input_deriv) const { - CUDNN_SAFE_CALL(cudnnConvolutionBackwardData( - CuDevice::Instantiate().GetCudnnHandle(), - &ONE, - params_desc_, - params.Data(), - output_desc_, - output_deriv.Data(), - conv_desc_, - bwd_data_algo_, - temp->Data(), - temp->Dim() * sizeof(BaseFloat), - &ZERO, - input_desc_, - input_deriv->Data())); +#ifdef HAVE_CUDNN + if (CuDevice::Instantiate().Enabled()) { + CuVector temp_space(temp_space_required_backward_data_ / + sizeof(BaseFloat), kUndefined); + CUDNN_SAFE_CALL(cudnnConvolutionBackwardData( + CuDevice::Instantiate().GetCudnnHandle(), + &ONE, + params_desc_, + params.Data(), + output_desc_, + output_deriv.Data(), + conv_desc_, + bwd_data_algo_, + temp_space.Data(), + temp_space.Dim() * sizeof(BaseFloat), + &ZERO, + input_desc_, + input_deriv->Data())); + } else +#endif + { + // TODO + } } +void ConvolutionComputation:: +ConvolveBackwardData(const MatrixBase ¶ms, + const MatrixBase &output_deriv, + MatrixBase *input_deriv) const { + // TODO +} + + + void ConvolutionComputation:: ConvolveBackwardParams(const CuMatrixBase &output_deriv, const CuMatrixBase &input, BaseFloat alpha, - CuVectorBase *temp, CuMatrixBase *params_deriv) const { - CUDNN_SAFE_CALL(cudnnConvolutionBackwardFilter( - CuDevice::Instantiate().GetCudnnHandle(), - &alpha, - input_desc_, - input.Data(), - output_desc_, - output_deriv.Data(), - conv_desc_, - bwd_filter_algo_, - temp->Data(), - temp->Dim() * sizeof(BaseFloat), - &ONE, - params_desc_, - params_deriv->Data())); +#ifdef HAVE_CUDNN + if (CuDevice::Instantiate().Enabled()) { + CuVector temp_space(temp_space_required_backward_params_ / + sizeof(BaseFloat), kUndefined); + CUDNN_SAFE_CALL(cudnnConvolutionBackwardFilter( + CuDevice::Instantiate().GetCudnnHandle(), + &alpha, + input_desc_, + input.Data(), + output_desc_, + output_deriv.Data(), + conv_desc_, + bwd_filter_algo_, + temp_space.Data(), + temp_space.Dim() * sizeof(BaseFloat), + &ONE, + params_desc_, + params_deriv->Data())); + } else +#endif + { + ConvolveBackwardParams(output_deriv.Mat(), input.Mat(), + alpha, &(params_deriv->Mat())); + } } + +void ConvolutionComputation:: +ConvolveBackwardParams(const MatrixBase &output_deriv, + const MatrixBase &input, + BaseFloat alpha, + MatrixBase *params_deriv) const { + // TODO +} + + void ConvolutionComputation:: ConvolveBackwardBias(const CuMatrixBase &output_deriv, BaseFloat alpha, CuVectorBase *bias_deriv) const { - CUDNN_SAFE_CALL(cudnnConvolutionBackwardBias( - CuDevice::Instantiate().GetCudnnHandle(), - &alpha, - output_desc_, - output_deriv.Data(), - &ONE, - bias_desc_, - bias_deriv->Data())); +#ifdef HAVE_CUDNN + if (CuDevice::Instantiate().Enabled()) { + CUDNN_SAFE_CALL(cudnnConvolutionBackwardBias( + CuDevice::Instantiate().GetCudnnHandle(), + &alpha, + output_desc_, + output_deriv.Data(), + &ONE, + bias_desc_, + bias_deriv->Data())); + } else +#endif + { + ConvolveBackwardBias(output_deriv.Mat(), alpha, &(bias_deriv->Vec())); + } } -} // namespace cudnn -} // namespace nnet3 -} // namespace kaldi +void ConvolutionComputation:: +ConvolveBackwardBias(const MatrixBase &output_deriv, + BaseFloat alpha, + VectorBase *bias_deriv) const { + // TODO. +} + + +// This function, called only if we are not using the GPU, converts +// the params from KCWH format to WHKC format (which is more convenient +// when using the CPU. Note: K == channels-out, C == channels-in. +void ConvolutionComputation::ConvertParams( + const MatrixBase ¶ms, + MatrixBase *params_rearranged) const { + KALDI_ASSERT(params.NumRows() == num_channels_out_ && + params.Stride() == params.NumCols() && + params_rearranged->NumRows() == filter_width_ * filter_height_ && + params_rearranged->Stride() == params_rearranged->NumCols()); + + // Reinterpret params as params_reinterpret which is of dimension KC * WH (instead of K * CWH). + SubMatrix params_reinterpret(params.Data(), + num_channels_out_ * num_channels_in_, + filter_width_ * filter_height_, + filter_width_ * filter_height_); + params_rearranged->CopyFromMat(params_reinterpret, kTrans); +} + + +} // namespace cudnn_convolution +} // namespace nnet3 +} // namespace kaldi diff --git a/src/nnet3/convolution-cudnn.h b/src/nnet3/convolution-cudnn.h index 8d3e6cdea21..828150af035 100644 --- a/src/nnet3/convolution-cudnn.h +++ b/src/nnet3/convolution-cudnn.h @@ -1,6 +1,7 @@ // nnet3/convolution-cudnn.h // Copyright 2018 Daniel Galvez +// 2018 Johns Hopkins University (author: Daniel Povey) // See ../../COPYING for clarification regarding multiple authors // @@ -24,12 +25,14 @@ #include "base/kaldi-common.h" #include "matrix/matrix-lib.h" #include "nnet3/convolution.h" +#if HAVE_CUDA == 1 +#include +#endif -#include namespace kaldi { namespace nnet3 { -namespace cudnn { +namespace cudnn_convolution { /** Represents structural information about a convolution computation, with @@ -40,16 +43,15 @@ namespace cudnn { re-used between different minibatches. This object is lightweight; it doesn't contain data, only a few integers and descriptors. - In the following docstrings, consider: - N to be equivalent to num_images - C to be equivalent to num_channels_in - K to be equivalent to num_channels_out - H to be equivalent to input_image_height (for images) or + In the following docstrings: + N is equivalent to num_images + C is equivalent to num_channels_in + K is equivalent to num_channels_out + H is equivalent to input_image_height or output_image_height (for images) or filter_height (for filter parameters). - W to be equivalent to input_image_width (for images) or + W is equivalent to input_image_width or output_image_width (for images) or filter_width (for filter parameters). - @param [in] num_channels_out Number of output channels, e.g. 64. @param [in] num_channels_in Number of input channels, e.g. 32. @param [in] filter_height Height of filter patch, e.g. 3 (for 3x3 kernel). Corresponds @@ -57,17 +59,17 @@ namespace cudnn { height in OCR applications. @param [in] filter_width Width of filter patch, e.g. 3 (for 3x3 kernel). Corresponds to the 'time' dimension in normal speech applications. - @param [in] filter_stride_height Filter stride in the height ('frequency') dimension. + @param [in] filter_stride_vertical Filter stride in the vertical ('frequency') dimension. Will normally be 1 in speech and OCR applications. - @param [in] filter_stride_width Filter stride in the width ('time') dimension. + @param [in] filter_stride_horizontal Filter stride in the horizontal ('time') dimension. Will usually be 1 in most layers, but may be 2 or 3 if we are doing subsampling on this layer (e.g. in reduced-frame-rate models like chain models). - @param [in] filter_dilation_height Filter dilation in the height ('frequency') + @param [in] filter_dilation_height Filter dilation in the vertical ('frequency') dimension. Equals the stride, in the input image, of individual elements of the filter patch. Will normally be 1. - @param [in] filter_dilation_width Filter dilation in the width ('time') + @param [in] filter_dilation_width Filter dilation in the horizontal ('time') dimension. Will normally be 1, but could be more than one if, for instance, you have components with time-stride > 1 which for some reason are required @@ -89,29 +91,22 @@ namespace cudnn { dimension). Likely to be 0 in many speech applications, since we normally deal with edge effects by padding with repeats of the first and last frame; but - padding is supported by the component. + padding is supported by this object. */ class ConvolutionComputation final { public: ConvolutionComputation(int32 num_channels_out, int32 num_channels_in, int32 filter_height, int32 filter_width, - int32 filter_stride_height, int32 filter_stride_width, + int32 filter_stride_vertical, int32 filter_stride_horizontal, int32 filter_dilation_height, int32 filter_dilation_width, int32 num_images, int32 input_image_height, int32 input_image_width, int32 zero_padding_height, int32 zero_padding_width); ~ConvolutionComputation(); - int32 OutputImageHeight() const; - int32 OutputImageWidth() const; - /** - * Returns the size of the workspace required for each stage, in - * bytes (_not_ 32-bit words). - */ - size_t TempSpaceRequiredForward() const; - size_t TempSpaceRequiredBackwardData() const; - size_t TempSpaceRequiredBackwardFilter() const; + int32 OutputImageHeight() const { return output_image_height_; } + int32 OutputImageWidth() const { return output_image_width_; } /** * For an explanation of the notation below (e.g. NWHC), see the @@ -123,49 +118,42 @@ class ConvolutionComputation final { * versa. This is not visible to the user; we mention it just in case * those familiar with CUDNN get surprised at the order * - * @param [in] input NWHC fully-packed tensor, with N == NumRows() - * @param [in] params KCWH fully-packed tensor, with K == NumRows() + * @param [in] input NWHC fully-packed tensor, with NumRows() == N * W + * @param [in] params KCWH fully-packed tensor, with NumRows() == K. * @param [in] bias vector of length K - * @param [in/out] temp_space Pointer to pre-allocated memory of size at least - * this->TempSpaceRequiredForward() bytes * @param [out] output Pre-allocated NWHK fully-packed tensor, with N == NumRows() */ void ConvolveForward(const CuMatrixBase &input, const CuMatrixBase ¶ms, const CuVectorBase &bias, - CuVectorBase *temp_space, CuMatrixBase *output) const; /** - * @param [in] params KCWH fully-packed tensor, with K == NumRows() - * @param [in] output_deriv NWHK fully-packed tensor, with N == NumRows() - * @param [in/out] temp_space Pointer to pre-allocated memory of size at least - * this->TempSpaceRequiredBackwardData() bytes - * @param [out] input_deriv Pre-allocated NWHC fully-packed tensor, with N == NumRows() + * @param [in] params KCWH fully-packed tensor, with NumRows() == K + * @param [in] output_deriv NWHK fully-packed tensor, with NumRows() == N * W + * @param [out] input_deriv Pre-allocated NWHC fully-packed tensor, with + * NumRows() == N * W */ void ConvolveBackwardData(const CuMatrixBase ¶ms, const CuMatrixBase &output_deriv, - CuVectorBase *temp_space, CuMatrixBase *input_deriv) const; /** - * @param [in] output_deriv NWHK fully-packed tensor, with N == NumRows() - * @param [in] input NWHC fully-packed tensor, with N == NumRows() + * @param [in] output_deriv NWHK fully-packed tensor, with NumRows() == N * W. + * @param [in] input NWHC fully-packed tensor, with NumRows() == N * W. * @param [in] alpha * params_deriv := alpha * gradient_computed + params_deriv - * @param [in] params KCWH fully-packed tensor, with K == NumRows() - * @param [in/out] temp_space Pointer to pre-allocated memory of size at least - * this->TempSpaceRequiredBackwardFilter() bytes - * @param [out] params_deriv Pre-allocated KCWH fully-packed tensor, with K == NumRows() + * @param [in] params KCWH fully-packed tensor, with NumRows() == K + * @param [out] params_deriv Pre-allocated KCWH fully-packed tensor, + * with NumRows() == K. */ void ConvolveBackwardParams(const CuMatrixBase &output_deriv, const CuMatrixBase &input, BaseFloat alpha, - CuVectorBase *temp_space, CuMatrixBase *params_deriv) const; /** - * @param [in] output_deriv NWHK fully-packed tensor, with N == NumRows() + * @param [in] output_deriv NWHK fully-packed tensor, with NumRows() * N * W. * @param [in] alpha * bias_deriv := alpha * gradient_computed + bias_deriv * @param [out] bias_deriv Pre-allocated vector of length K @@ -174,7 +162,86 @@ class ConvolutionComputation final { BaseFloat alpha, CuVectorBase *bias_deriv) const; + // The CPU versions of the functions declared above allow the user to use the + // CPU even if a GPU is active. They are also called by the versions of the + // same name that take CuMatrix types, in the case when either we did not + // compile for CUDA or we did but we are not using a GPU. + void ConvolveForward(const MatrixBase &input, + const MatrixBase ¶ms, + const VectorBase &bias, + MatrixBase *output) const; + void ConvolveBackwardData(const MatrixBase ¶ms, + const MatrixBase &output_deriv, + MatrixBase *input_deriv) const; + void ConvolveBackwardParams(const MatrixBase &output_deriv, + const MatrixBase &input, + BaseFloat alpha, + MatrixBase *params_deriv) const; + void ConvolveBackwardBias(const MatrixBase &output_deriv, + BaseFloat alpha, + VectorBase *bias_deriv) const; + + + + + void Write(std::ostream &os, bool binary) const; + + void Read(std::istream &os, bool binary); + private: +#if HAVE_CUDA == 1 + // initialize the various descriptors; only called if compiled for CUDA + // AND we are using a GPU. + void InitCudnn(); + // ComputeTempSpaceSizes() is called from InitCudnn(); it sets the + // temp_space_*_ member variables. + void ComputeTempSpaceSizes(); + // Destroy the descriptors. + void DestroyCudnn(); +#endif + + // Called from the constructor and Read(), this sets output_image_height_. + void ComputeOutputImageHeight(); + // Called from the constructor and Read(), this sets output_image_width_. + void ComputeOutputImageWidth(); + + + // This function, called only if we are not using the GPU, converts + // the params from KCWH format to WHKC format (which is more convenient + // when using the CPU. params and params_rearranged must both be + // packed (Stride() == NumCols()), params must have num-rows equal to K + // (num_channels_out_), and params_rearranged must have num-rows equal + // to to WH (filter_width_ * filter_height_). + void ConvertParams(const MatrixBase ¶ms, + MatrixBase *params_rearranged) const; + // This function does the opposite transformation of what ConvertParams() + // does. + void ConvertParamsBack(const MatrixBase ¶ms_rearranged, + MatrixBase *params) const; + + + + // The following block of members are just copies of the args to the + // constructor. Please see the documentation of the constructor, and look for + // the similarly named parameter, to understand the meaning of these + // individual members. + int32 num_channels_out_; + int32 num_channels_in_; + int32 filter_height_; + int32 filter_width_; + int32 filter_stride_vertical_; + int32 filter_stride_horizontal_; + int32 filter_dilation_height_; + int32 filter_dilation_width_; + int32 num_images_; + int32 input_image_height_; + int32 input_image_width_; + int32 zero_padding_height_; + int32 zero_padding_width_; + int32 output_image_height_; + int32 output_image_width_; + +#if HAVE_CUDA == 1 cudnnTensorDescriptor_t input_desc_; cudnnTensorDescriptor_t output_desc_; cudnnFilterDescriptor_t params_desc_; @@ -185,10 +252,18 @@ class ConvolutionComputation final { cudnnConvolutionFwdAlgo_t fwd_algo_; cudnnConvolutionBwdFilterAlgo_t bwd_filter_algo_; cudnnConvolutionBwdDataAlgo_t bwd_data_algo_; + + // The units of the following are all in bytes. + size_t temp_space_required_forward_; + size_t temp_space_required_backward_data_; + size_t temp_space_required_backward_filter_; +#endif + + }; -} // namespace cudnn -} // namespace nnet3 -} // namespace kaldi +} // namespace cudnn_convolution +} // namespace nnet3 +} // namespace kaldi #endif // KALDI_NNET3_NNET_CUDNN_CONVOLUTION_H_ From 700085028ceca2203f6ff2f0dd4fc75b61f89774 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 16 Oct 2018 00:19:55 -0400 Subject: [PATCH 13/22] [src] Some refactoring; start adding tests. --- src/nnet3/Makefile | 3 +- src/nnet3/convolution-cudnn-test.cc | 483 ++++++++++++++++++++++++++++ src/nnet3/convolution-cudnn.cc | 368 +++++++++++---------- src/nnet3/convolution-cudnn.h | 216 +++++++------ tools/Makefile | 8 +- 5 files changed, 798 insertions(+), 280 deletions(-) create mode 100644 src/nnet3/convolution-cudnn-test.cc diff --git a/src/nnet3/Makefile b/src/nnet3/Makefile index b39e3d64f31..710c019cd56 100644 --- a/src/nnet3/Makefile +++ b/src/nnet3/Makefile @@ -12,7 +12,8 @@ TESTFILES = natural-gradient-online-test nnet-graph-test \ nnet-compile-utils-test nnet-nnet-test nnet-utils-test \ nnet-compile-test nnet-analyze-test nnet-compute-test \ nnet-optimize-test nnet-derivative-test nnet-example-test \ - nnet-common-test convolution-test attention-test + nnet-common-test convolution-test attention-test \ + convolution-cudnn-test OBJFILES = nnet-common.o nnet-compile.o nnet-component-itf.o \ nnet-simple-component.o nnet-normalize-component.o \ diff --git a/src/nnet3/convolution-cudnn-test.cc b/src/nnet3/convolution-cudnn-test.cc new file mode 100644 index 00000000000..559857abf7b --- /dev/null +++ b/src/nnet3/convolution-cudnn-test.cc @@ -0,0 +1,483 @@ +// nnet3/convolution-cudnn-test.cc + +// Copyright 2018 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "nnet3/convolution-cudnn.h" +#include "util/common-utils.h" + +namespace kaldi { +namespace nnet3 { +namespace cudnn_convolution { + +// for testing purposes, create a random ConvolutionComputation +static void GetRandomConfig(ConvolutionComputationConfig *config) { + config->num_images = RandInt(1, 10); + config->num_channels_out = RandInt(1, 10); + config->num_channels_in = RandInt(1, 10); + + config->filter_height = RandInt(1, 3); + config->filter_width = RandInt(1, 3); + + // TODO: randomize the following as well. For now we just use + // default values. + config->filter_stride_vertical = 1; + config->filter_stride_horizontal = 1; + config->filter_dilation_vertical = 1; + config->filter_dilation_horizontal = 1; + + config->input_image_height = RandInt(3, 10); + config->input_image_width = RandInt(3, 10); + + config->zero_padding_vertical = 0; + config->zero_padding_horizontal = 0; + + config->Check(); + config->ComputeOutputImageSize(); +} + +void TestConvolutionComputationConfig() { + for (int32 i = 0; i < 10; i++) { + ConvolutionComputationConfig config; + GetRandomConfig(&config); + std::ostringstream os; + bool binary = true; + config.Write(os, binary); + + ConvolutionComputationConfig config2; + std::istringstream is(os.str()); + config2.Read(is, binary); + std::ostringstream os2; + config2.Write(os2, binary); + KALDI_ASSERT(os.str() == os2.str()); + } +} + +void TestConvolutionComputation() { + for (int32 i = 0; i < 10; i++) { + ConvolutionComputationConfig config; + GetRandomConfig(&config); + ConvolutionComputation computation(config); + + std::ostringstream os; + bool binary = true; + computation.Write(os, binary); + + ConvolutionComputation computation2; + std::istringstream is(os.str()); + computation2.Read(is, binary); + std::ostringstream os2; + computation2.Write(os2, binary); + KALDI_ASSERT(os.str() == os2.str()); + } +} + + +/** + + +// for testing purposes, create a set of input and output indexes for +// a convolution computation that are computable given this model. +static void GetRandomConvolutionIndexes(const ConvolutionModel &model, + std::vector *input_indexes, + std::vector *output_indexes) { + KALDI_ASSERT(model.Check()); + + std::vector > n_x_pairs; + int32 num_n_x_pairs = RandInt(1, 3); + for (int32 i = 0; i < num_n_x_pairs; i++) { + int32 n = RandInt(0, 3), x = RandInt(0, 1); + n_x_pairs.push_back(std::pair(n, x)); + } + SortAndUniq(&n_x_pairs); + num_n_x_pairs = n_x_pairs.size(); + + + // 'output_t_values' is the set of *possible* output + // t values; we'll later sub-sample from these. + std::vector output_t_values; + + { + int32 out_t_start = RandInt(-5, 5), out_t_step = RandInt(1, 3), + num_t_out = RandInt(1, 4); + for (int32 i = 0; i < num_t_out; i++) + output_t_values.push_back(out_t_start + i * out_t_step); + } + + input_indexes->clear(); + output_indexes->clear(); + for (size_t i = 0; i < n_x_pairs.size(); i++) { + std::vector chosen_output_t_values; + while (chosen_output_t_values.empty()) { + for (size_t j = 0; j < output_t_values.size(); j++) + if (RandInt(0, 1) != 0) + chosen_output_t_values.push_back(output_t_values[j]); + } + KALDI_ASSERT(IsSortedAndUniq(chosen_output_t_values)); + + std::set required_input_t_values, + usable_input_t_values; + for (size_t j = 0; j < chosen_output_t_values.size(); j++) { + std::set::const_iterator iter; + int32 t_out = chosen_output_t_values[j]; + for (iter = model.required_time_offsets.begin(); + iter != model.required_time_offsets.end(); iter++) { + int32 offset = *iter; + required_input_t_values.insert(t_out + offset); + } + for (iter = model.all_time_offsets.begin(); + iter != model.all_time_offsets.end(); iter++) { + int32 offset = *iter; + usable_input_t_values.insert(t_out + offset); + } + } + + // add to output_indexes + for (size_t j = 0; j < chosen_output_t_values.size(); j++) { + int32 t_out = chosen_output_t_values[j]; + Index index; + index.n = n_x_pairs[i].first; + index.x = n_x_pairs[i].second; + index.t = t_out; + output_indexes->push_back(index); + } + + std::vector chosen_input_t_values(required_input_t_values.begin(), + required_input_t_values.end()); + for (std::set::const_iterator iter = usable_input_t_values.begin(); + iter != usable_input_t_values.end(); ++iter) { + int32 t = *iter; + if (RandInt(0, 1) == 0) + chosen_input_t_values.push_back(t); + } + SortAndUniq(&chosen_input_t_values); + + // add to input_indexes + for (size_t j = 0; j < chosen_input_t_values.size(); j++) { + int32 t_in = chosen_input_t_values[j]; + Index index; + index.n = n_x_pairs[i].first; + index.x = n_x_pairs[i].second; + index.t = t_in; + input_indexes->push_back(index); + } + } +} + + +void UnitTestTimeHeightConvolutionIo() { + for (int32 i = 0; i < 10; i++) { + KALDI_LOG << "iter = " << i; + // Create a ConvolutionModel and test its I/O. + ConvolutionModel conv_model; + GetRandomConvolutionModel(&conv_model); + std::ostringstream os1, os2; + bool binary = (RandInt(0, 1) == 0); + conv_model.Write(os1, binary); + std::istringstream is(os1.str()); + ConvolutionModel conv_model2; + conv_model2.Read(is, binary); + conv_model2.Write(os2, binary); + KALDI_ASSERT(os1.str() == os2.str() && conv_model2.Check()); + } +} + +void TestComputationIo(const ConvolutionComputation &computation) { + std::ostringstream os1, os2; + bool binary = (RandInt(0, 1) == 0); + computation.Write(os1, binary); + std::istringstream is(os1.str()); + ConvolutionComputation computation2; + computation2.Read(is, binary); + computation2.Write(os2, binary); + KALDI_ASSERT(os1.str() == os2.str()); + computation2.Check(); +} + + +// This function exects indexes.size() == matrix->NumRows(); +// it sets to zero any row i of the matrix for which +// indexes[i].t == kNoTime. +void ZeroBlankRows(const std::vector &indexes, + CuMatrix *matrix) { + KALDI_ASSERT(static_cast(indexes.size()) == matrix->NumRows()); + int32 num_rows = matrix->NumRows(); + if (num_rows == 0) return; + Vector mask(num_rows, kUndefined); + mask.Set(1.0); + const Index *indexes_ptr = &(indexes[0]); + BaseFloat *mask_ptr = mask.Data(); + for (int32 r = 0; r < num_rows; r++) { + if (indexes_ptr[r].t == kNoTime) + mask_ptr[r] = 0.0; + } + CuVector cu_mask; + cu_mask.Swap(&mask); + matrix->MulRowsVec(cu_mask); +} + +// This is a 'dumb' implementation of convolution, created to compare +// with ConvolveForward. +void ConvolveForwardSimple( + const ConvolutionModel &model, + const std::vector &input_indexes, + const std::vector &output_indexes, + const CuMatrixBase &input_cu, + const CuMatrixBase ¶ms_cu, + CuMatrixBase *output_cu) { + // these loops will be very slow on GPU, so do it all on CPU. + Matrix input(input_cu), params(params_cu), + output(*output_cu); + std::unordered_map index_to_row; + int32 input_rows = input.NumRows(), + output_rows = output.NumRows(); + for (int32 r_in = 0; r_in < input_rows; r_in++) { + if (input_indexes[r_in].t != kNoTime) { + index_to_row[input_indexes[r_in]] = r_in; + } + } + int32 num_offsets = model.offsets.size(), + num_filters_in = model.num_filters_in, + num_filters_out = model.num_filters_out, + height_in = model.height_in, + height_out = model.height_out, + height_subsample_out = model.height_subsample_out; + for (int32 r_out = 0; r_out < output_rows; r_out++) { + Index index_out = output_indexes[r_out]; + if (index_out.t == kNoTime) + continue; + SubVector output_row(output, r_out); + for (int32 o = 0; o < num_offsets; o++) { + int32 time_offset = model.offsets[o].time_offset, + height_offset = model.offsets[o].height_offset; + Index index_in(index_out); + index_in.t += time_offset; + std::unordered_map::const_iterator iter = + index_to_row.find(index_in); + if (iter != index_to_row.end()) { + SubMatrix params_part(params, 0, params.NumRows(), + o * num_filters_in, num_filters_in); + int32 r_in = iter->second; + SubVector input_row(input, r_in); + for (int32 h_out_subsampled = 0; + h_out_subsampled < height_out; + h_out_subsampled++) { + int32 h_out = h_out_subsampled * height_subsample_out, + h_in = h_out + height_offset; + if (h_in < 0 || h_in >= height_in) + continue; + SubVector output_part(output_row, + h_out_subsampled * num_filters_out, + num_filters_out), + input_part(input_row, h_in * num_filters_in, num_filters_in); + output_part.AddMatVec(1.0, params_part, kNoTrans, input_part, 1.0); + } + } + } + } + output_cu->CopyFromMat(output); +} + + + +void TestRunningComputation(const ConvolutionModel &conv_model, + const std::vector &input_indexes, + const std::vector &output_indexes, + const ConvolutionComputation &computation) { + CuMatrix input(input_indexes.size(), conv_model.InputDim(), + kSetZero, kStrideEqualNumCols), + output(output_indexes.size(), conv_model.OutputDim(), + kSetZero, kStrideEqualNumCols), + output2(output), + params(conv_model.ParamRows(), conv_model.ParamCols()); + input.SetRandn(); + params.SetRandn(); + ZeroBlankRows(input_indexes, &input); + ConvolveForward(computation, input, params, &output); + ZeroBlankRows(output_indexes, &output); + + ConvolveForwardSimple(conv_model, input_indexes, output_indexes, + input, params, &output2); + KALDI_LOG << "Tested convolution for model: " + << conv_model.Info(); + if (!output.ApproxEqual(output2, 0.001)) { + KALDI_LOG << "Output is: " << output; + KALDI_LOG << "Output2 is: " << output2; + KALDI_ERR << "Convolution test failure."; + } +} + + +void TestDataBackprop(const ConvolutionModel &conv_model, + const std::vector &input_indexes, + const std::vector &output_indexes, + const ConvolutionComputation &computation) { + CuMatrix + input_deriv(input_indexes.size(), conv_model.InputDim(), + kSetZero, kStrideEqualNumCols), + input(input_indexes.size(), conv_model.InputDim(), + kSetZero, kStrideEqualNumCols), + output(output_indexes.size(), conv_model.OutputDim(), + kSetZero, kStrideEqualNumCols), + output_deriv(output_indexes.size(), conv_model.OutputDim(), + kSetZero, kStrideEqualNumCols), + params(conv_model.ParamRows(), conv_model.ParamCols()); + + input.SetRandn(); + params.SetRandn(); + output_deriv.SetRandn(); + + ZeroBlankRows(output_indexes, &output_deriv); + ConvolveBackwardData(computation, params, output_deriv, &input_deriv); + ZeroBlankRows(input_indexes, &input_deriv); + ZeroBlankRows(input_indexes, &input); + + // define the objf as TraceMatMat(output_deriv, output, kTrans). + // we can work it out from the backpropagated data-derivative. + BaseFloat expected_objf = TraceMatMat(input_deriv, input, kTrans); + + ConvolveForward(computation, input, params, &output); + ZeroBlankRows(output_indexes, &output); + + BaseFloat observed_objf = TraceMatMat(output, output_deriv, kTrans); + + KALDI_LOG << "Expected objf = " << expected_objf + << ", observed objf = " << observed_objf; + if (!ApproxEqual(expected_objf, observed_objf, 0.1) && + fabs(expected_objf) < 1.0) { + KALDI_ERR << "Difference in objf too large."; + } +} + + +void TestParamsBackprop(const ConvolutionModel &conv_model, + const std::vector &input_indexes, + const std::vector &output_indexes, + const ConvolutionComputation &computation) { + CuMatrix + input(input_indexes.size(), conv_model.InputDim(), + kSetZero, kStrideEqualNumCols), + output(output_indexes.size(), conv_model.OutputDim(), + kSetZero, kStrideEqualNumCols), + output_deriv(output_indexes.size(), conv_model.OutputDim(), + kSetZero, kStrideEqualNumCols), + params(conv_model.ParamRows(), conv_model.ParamCols()), + params_deriv(conv_model.ParamRows(), conv_model.ParamCols()); + + input.SetRandn(); + params.SetRandn(); + output_deriv.SetRandn(); + + BaseFloat alpha = 0.5 * RandInt(1, 3); + + ZeroBlankRows(output_indexes, &output_deriv); + ZeroBlankRows(input_indexes, &input); + + ConvolveBackwardParams(computation, input, output_deriv, alpha, + ¶ms_deriv); + + BaseFloat expected_objf = TraceMatMat(params_deriv, params, kTrans) / alpha; + + ConvolveForward(computation, input, params, &output); + + ZeroBlankRows(output_indexes, &output); + + BaseFloat observed_objf = TraceMatMat(output, output_deriv, kTrans); + + KALDI_LOG << "Expected objf = " << expected_objf + << ", observed objf = " << observed_objf; + if (!ApproxEqual(expected_objf, observed_objf, 0.1) && + fabs(expected_objf) < 1.0) { + KALDI_ERR << "Difference in objf too large."; + } +} + + + +void UnitTestTimeHeightConvolutionCompile() { + for (int32 i = 0; i < 10; i++) { + KALDI_LOG << "iter = " << i; + // Create a ConvolutionModel + ConvolutionModel conv_model; + GetRandomConvolutionModel(&conv_model); + std::vector input_indexes, output_indexes; + GetRandomConvolutionIndexes(conv_model, &input_indexes, &output_indexes); + + ConvolutionComputationOptions opts; + ConvolutionComputation computation; + std::vector input_indexes_modified, output_indexes_modified; + CompileConvolutionComputation(conv_model, input_indexes, output_indexes, + opts, &computation, + &input_indexes_modified, + &output_indexes_modified); + TestComputationIo(computation); + TestRunningComputation(conv_model, + input_indexes_modified, + output_indexes_modified, + computation); + TestDataBackprop(conv_model, + input_indexes_modified, + output_indexes_modified, + computation); + TestParamsBackprop(conv_model, + input_indexes_modified, + output_indexes_modified, + computation); + std::ostringstream os; + os << "\nInput-indexes: "; + WriteIndexVector(os, false, input_indexes); + os << "\nInput-indexes-modified: "; + WriteIndexVector(os, false, input_indexes_modified); + os << "\nOutput-indexes: "; + WriteIndexVector(os, false, output_indexes); + os << "\nOutput-indexes-modified: "; + WriteIndexVector(os, false, output_indexes_modified); + KALDI_LOG << os.str(); + } +} + + +void UnitTestTimeHeightConvolution() { + UnitTestTimeHeightConvolutionIo(); + UnitTestTimeHeightConvolutionCompile(); +} + +*/ + +} // namespace cudnn_convolution +} // namespace nnet3 +} // namespace kaldi + + +int main() { + using namespace kaldi; + using namespace kaldi::nnet3; + using namespace kaldi::nnet3::cudnn_convolution; + + + for (int32 loop = 0; loop < 2; loop++) { +#if HAVE_CUDA == 1 + CuDevice::Instantiate().SetDebugStrideMode(true); + if (loop == 0) + CuDevice::Instantiate().SelectGpuId("no"); // -1 means no GPU + else + CuDevice::Instantiate().SelectGpuId("optional"); // -2 .. automatic selection +#endif + TestConvolutionComputationConfig(); + TestConvolutionComputation(); + } +} diff --git a/src/nnet3/convolution-cudnn.cc b/src/nnet3/convolution-cudnn.cc index 4c5a69d3532..9f406d4868e 100644 --- a/src/nnet3/convolution-cudnn.cc +++ b/src/nnet3/convolution-cudnn.cc @@ -32,40 +32,127 @@ const BaseFloat ONE(1.0); const BaseFloat ZERO(0.0); } -ConvolutionComputation:: -ConvolutionComputation(int32 num_channels_out, int32 num_channels_in, - int32 filter_height, int32 filter_width, - int32 filter_stride_vertical, int32 filter_stride_horizontal, - int32 filter_dilation_height, - int32 filter_dilation_width, - int32 num_images, - int32 input_image_height, int32 input_image_width, - int32 zero_padding_height, int32 zero_padding_width): - num_channels_out_(num_channels_out), - num_channels_in_(num_channels_in), - filter_height_(filter_height), - filter_width_(filter_width), - filter_stride_vertical_(filter_stride_vertical), - filter_stride_horizontal_(filter_stride_horizontal), - filter_dilation_height_(filter_dilation_height), - filter_dilation_width_(filter_dilation_width), - num_images_(num_images), - input_image_height_(input_image_height), - input_image_width_(input_image_width), - zero_padding_height_(zero_padding_height), - zero_padding_width_(zero_padding_width) { + +void ConvolutionComputationConfig::Check() { + KALDI_ASSERT(num_images > 0 && num_channels_out > 0 && + num_channels_in > 0 && filter_height > 0 && filter_width > 0); + KALDI_ASSERT(filter_stride_vertical > 0 && filter_stride_horizontal > 0 && + filter_dilation_vertical > 0 && filter_dilation_horizontal > 0); + KALDI_ASSERT(input_image_height > 0 && input_image_width > 0 && + zero_padding_vertical >= 0 && zero_padding_horizontal >= 0); +} + +void ConvolutionComputationConfig::ComputeOutputImageSize() { + { // This blocks deals with the vertical direction. + + // 'filter_height_reduction' is the amount by which the height of the filter patch + // reduces the effective height of the input image. It's the distance between + // the first and last pixels of the filter patch. E.g. in a 3x3 kernel it + // would be 2. + int32 filter_height_reduction = (filter_height - 1) * filter_dilation_vertical; + // 'modified_input_height' is the number of times we can shift the filter patch + // (not yet taking account of any filter stride). It's a kind of augmented input-image + // height, after applying zero-padding and subtracting filter_height_reduction. + int32 modified_input_height = + input_image_height - filter_height_reduction + (zero_padding_vertical * 2), + s = filter_stride_vertical; + + // output_image_height equals reduced_input_height divided by s (but rounding + // up), which is the number of times we can shift the filter patch by + // filter_stride_vertical_. + output_image_height = (modified_input_height + s - 1) / s; + } + + { // This blocks deals with the horizontal direction. + + // 'filter_width_reduction' is the amount by which the width of the filter patch + // reduces the effective width of the input image. It's the distance between + // the first and last pixels of the filter patch. E.g. in a 3x3 kernel it + // would be 2. + int32 filter_width_reduction = (filter_width - 1) * filter_dilation_horizontal; + // 'modified_input_width' is the number of times we can shift the filter patch + // (not yet taking account of any filter stride). It's a kind of augmented input-image + // width, after applying zero-padding and subtracting filter_width_reduction. + int32 modified_input_width = + input_image_width - filter_width_reduction + (zero_padding_horizontal * 2), + s = filter_stride_horizontal; + + // output_image_width equals reduced_input_width divided by s (but rounding + // up), which is the number of times we can shift the filter patch by + // filter_stride_horizontal_. + output_image_width = (modified_input_width + s - 1) / s; + } +} + +void ConvolutionComputationConfig::Write(std::ostream &os, bool binary) const { + WriteToken(os, binary, ""); + WriteBasicType(os, binary, num_images); + WriteToken(os, binary, ""); + WriteBasicType(os, binary, num_channels_in); + WriteBasicType(os, binary, num_channels_out); + WriteToken(os, binary, ""); + WriteBasicType(os, binary, filter_height); + WriteBasicType(os, binary, filter_width); + WriteBasicType(os, binary, filter_stride_vertical); + WriteBasicType(os, binary, filter_stride_horizontal); + WriteBasicType(os, binary, filter_dilation_vertical); + WriteBasicType(os, binary, filter_dilation_horizontal); + WriteToken(os, binary, ""); + WriteBasicType(os, binary, input_image_height); + WriteBasicType(os, binary, input_image_width); + WriteToken(os, binary, ""); + WriteBasicType(os, binary, zero_padding_vertical); + WriteBasicType(os, binary, zero_padding_horizontal); + WriteToken(os, binary, ""); +} + +void ConvolutionComputationConfig::Read(std::istream &is, bool binary) { + ExpectToken(is, binary, ""); + ReadBasicType(is, binary, &num_images); + ExpectToken(is, binary, ""); + ReadBasicType(is, binary, &num_channels_in); + ReadBasicType(is, binary, &num_channels_out); + ExpectToken(is, binary, ""); + ReadBasicType(is, binary, &filter_height); + ReadBasicType(is, binary, &filter_width); + ReadBasicType(is, binary, &filter_stride_vertical); + ReadBasicType(is, binary, &filter_stride_horizontal); + ReadBasicType(is, binary, &filter_dilation_vertical); + ReadBasicType(is, binary, &filter_dilation_horizontal); + ExpectToken(is, binary, ""); + ReadBasicType(is, binary, &input_image_height); + ReadBasicType(is, binary, &input_image_width); + ExpectToken(is, binary, ""); + ReadBasicType(is, binary, &zero_padding_vertical); + ReadBasicType(is, binary, &zero_padding_horizontal); + ExpectToken(is, binary, ""); +} + + +ConvolutionComputation::ConvolutionComputation( + const ConvolutionComputationConfig &config): config_(config) { + config_.Check(); + config_.ComputeOutputImageSize(); #if HAVE_CUDA == 1 if (CuDevice::Instantiate().Enabled()) { InitCudnn(); } #endif - // The following is called whether or not we are using CUDA. - ComputeOutputImageHeight(); - ComputeOutputImageWidth(); } +ConvolutionComputation::ConvolutionComputation() { +#if HAVE_CUDA == 1 + descriptors_initialized_ = false; +#endif +} + + #if HAVE_CUDA == 1 void ConvolutionComputation::InitCudnn() { + descriptors_initialized_ = true; + + const ConvolutionComputationConfig &c = config_; + CUDNN_SAFE_CALL(cudnnCreateTensorDescriptor(&input_desc_)); CUDNN_SAFE_CALL(cudnnCreateTensorDescriptor(&output_desc_)); CUDNN_SAFE_CALL(cudnnCreateFilterDescriptor(¶ms_desc_)); @@ -75,34 +162,50 @@ void ConvolutionComputation::InitCudnn() { CUDNN_SAFE_CALL( cudnnSetTensor4dDescriptor(input_desc_, CUDNN_TENSOR_NHWC, - CUDNN_DATA_FLOAT, num_images_, - num_channels_in_, input_image_width_, - input_image_height_)); + CUDNN_DATA_FLOAT, c.num_images, + c.num_channels_in, c.input_image_width, + c.input_image_height)); CUDNN_SAFE_CALL( - cudnnSetConvolution2dDescriptor(conv_desc_, - zero_padding_width_, zero_padding_height_, - filter_stride_horizontal_, filter_stride_vertical_, - filter_dilation_width_, filter_dilation_height_, - CUDNN_CROSS_CORRELATION, // TODO: Double check this! - CUDNN_DATA_FLOAT)); + cudnnSetConvolution2dDescriptor( + conv_desc_, + c.zero_padding_horizontal, c.zero_padding_vertical, + c.filter_stride_horizontal, c.filter_stride_vertical, + c.filter_dilation_horizontal, c.filter_dilation_vertical, + CUDNN_CROSS_CORRELATION, // TODO: Double check this! + CUDNN_DATA_FLOAT)); CUDNN_SAFE_CALL( cudnnSetFilter4dDescriptor(params_desc_, CUDNN_DATA_FLOAT, - CUDNN_TENSOR_NCHW, num_channels_out_, - num_channels_in_, filter_width_, filter_height_)); + CUDNN_TENSOR_NCHW, c.num_channels_out, + c.num_channels_in, c.filter_width, + c.filter_height)); + + int32 kaldi_height_cudnn_width, kaldi_width_cudnn_height, unused; + CUDNN_SAFE_CALL( + cudnnGetConvolution2dForwardOutputDim(conv_desc_, input_desc_, + params_desc_, + &unused, &unused, + &kaldi_width_cudnn_height, + &kaldi_height_cudnn_width)); + + if (kaldi_height_cudnn_width != c.output_image_height) + KALDI_ERR << "Code error: the height from CUDNN " << kaldi_height_cudnn_width + << " does not match our value " << c.output_image_height; + if (kaldi_width_cudnn_height != c.output_image_width) + KALDI_ERR << "Code error: the width from CUDNN " << kaldi_width_cudnn_height + << " does not match our value " << c.output_image_width; + // These two member functions depend only on input_desc_, // conv_desc_, and params_desc_, so they are safe to call now. - int32 out_kaldi_height_cudnn_width = OutputImageHeight(); - int32 out_kaldi_width_cudnn_height = OutputImageWidth(); CUDNN_SAFE_CALL( cudnnSetTensor4dDescriptor(output_desc_, CUDNN_TENSOR_NHWC, - CUDNN_DATA_FLOAT, num_images_, - num_channels_in_, out_kaldi_width_cudnn_height, - out_kaldi_height_cudnn_width)); + CUDNN_DATA_FLOAT, c.num_images, + c.num_channels_in, kaldi_width_cudnn_height, + kaldi_height_cudnn_width)); const int32 bias_stride[] = {1}; CUDNN_SAFE_CALL( cudnnSetTensorNdDescriptor(bias_desc_, CUDNN_DATA_FLOAT, 1, - &num_channels_out_, bias_stride)); + &c.num_channels_out, bias_stride)); const double DONT_CARE = 0; CUDNN_SAFE_CALL( @@ -218,97 +321,19 @@ void ConvolutionComputation::DestroyCudnn() { ConvolutionComputation::~ConvolutionComputation() { #if HAVE_CUDA == 1 - if (CuDevice::Instantiate().Enabled()) + if (CuDevice::Instantiate().Enabled() && descriptors_initialized_) DestroyCudnn(); #endif } -void ConvolutionComputation::ComputeOutputImageHeight() { - // 'filter_height_reduction' is the amount by which the height of the filter patch - // reduces the effective height of the input image. It's the distance between - // the first and last pixels of the filter patch. E.g. in a 3x3 kernel it - // would be 2. - int32 filter_height_reduction = (filter_height_ - 1) * filter_dilation_height_; - // 'modified_input_height' is the number of times we can shift the filter patch - // (not yet taking account of any filter stride). It's a kind of augmented input-image - // height, after applying zero-padding and subtracting filter_height_reduction. - int32 modified_input_height = - input_image_height_ - filter_height_reduction + (zero_padding_height_ * 2), - s = filter_stride_vertical_; - - // output_image_height_ equals reduced_input_height divided by s (but rounding - // up), which is the number of times we can shift the filter patch by - // filter_stride_vertical_. - output_image_height_ = (modified_input_height + s - 1) / s; - -#if HAVE_CUDA == 1 - // Check that CUDA has the same idea of what the output image height is, as we - // do. This helps check that the CPU and GPU computations are compatible. - int32 unused; - int32 kaldi_height_cudnn_width; - CUDNN_SAFE_CALL( - cudnnGetConvolution2dForwardOutputDim(conv_desc_, input_desc_, - params_desc_, - &unused, &unused, - &unused, - &kaldi_height_cudnn_width)); - if (kaldi_height_cudnn_width != output_image_height_) { - KALDI_ERR << "Code error: the height from CUDNN " << kaldi_height_cudnn_width - << " does not match our value " << output_image_height_; - } -#endif -} - -void ConvolutionComputation::ComputeOutputImageWidth() { - // 'filter_width_reduction' is the amount by which the width of the filter patch - // reduces the effective width of the input image. It's the distance between - // the first and last pixels of the filter patch. E.g. in a 3x3 kernel it - // would be 2. - int32 filter_width_reduction = (filter_width_ - 1) * filter_dilation_width_; - // 'modified_input_width' is the number of times we can shift the filter patch - // (not yet taking account of any filter stride). It's a kind of augmented input-image - // width, after applying zero-padding and subtracting filter_width_reduction. - int32 modified_input_width = - input_image_width_ - filter_width_reduction + (zero_padding_width_ * 2), - s = filter_stride_horizontal_; - - // output_image_width equals reduced_input_width divided by s (but rounding - // up), which is the number of times we can shift the filter patch by - // filter_stride_horizontal_. - output_image_width_ = (modified_input_width + s - 1) / s; -#if HAVE_CUDA == 1 - int32 unused; - int32 kaldi_width_cudnn_height; - CUDNN_SAFE_CALL( - cudnnGetConvolution2dForwardOutputDim(conv_desc_, input_desc_, - params_desc_, - &unused, &unused, - &kaldi_width_cudnn_height, - &unused)); - if (kaldi_width_cudnn_height != output_image_width_) { - KALDI_ERR << "Code error: the height from CUDNN " << kaldi_width_cudnn_height - << " does not match our value " << output_image_width_; - } -#endif -} - - -void ConvolutionComputation::Write(std::ostream &os, bool binary) const { - // TODO: write just num_channels_out_ through zero_padding_width_; - -} void ConvolutionComputation::Read(std::istream &is, bool binary) { - // TODO: read just num_channels_out_ through zero_padding_width_; - + config_.Read(is, binary); #if HAVE_CUDA == 1 if (CuDevice::Instantiate().Enabled()) { InitCudnn(); } #endif - // The following are called whether or not we have CUDA. - ComputeOutputImageHeight(); - ComputeOutputImageWidth(); } @@ -317,17 +342,18 @@ ConvolveForward(const CuMatrixBase &input, const CuMatrixBase ¶ms, const CuVectorBase &bias, CuMatrixBase *output) const { + const ConvolutionComputationConfig &c = config_; // Check some dimensions. KALDI_ASSERT( - input.NumRows() == num_images_ * input_image_width_ && - input.NumCols() == input_image_height_ * num_channels_in_ && + input.NumRows() == c.num_images * c.input_image_width && + input.NumCols() == c.input_image_height * c.num_channels_in && input.Stride() == input.NumCols() && - params.NumRows() == num_channels_out_ && - params.NumCols() == num_channels_in_ * filter_height_ * filter_width_ && + params.NumRows() == c.num_channels_out && + params.NumCols() == c.num_channels_in * c.filter_height * c.filter_width && params.Stride() == params.NumCols() && - bias.Dim() == num_channels_out_ && - output->NumRows() == num_images_ * input_image_height_ && - output->NumCols() == input_image_width_ * num_channels_out_ && + bias.Dim() == c.num_channels_out && + output->NumRows() == c.num_images * c.input_image_height && + output->NumCols() == c.input_image_width * c.num_channels_out && output->Stride() == output->NumCols()); #ifdef HAVE_CUDNN @@ -367,69 +393,70 @@ ConvolveForward(const MatrixBase &input, const MatrixBase ¶ms, const VectorBase &bias, MatrixBase *output) const { + const ConvolutionComputationConfig &c = config_; // Check some dimensions. KALDI_ASSERT( - input.NumRows() == num_images_ * input_image_width_ && - input.NumCols() == input_image_height_ * num_channels_in_ && + input.NumRows() == c.num_images * c.input_image_width && + input.NumCols() == c.input_image_height * c.num_channels_in && input.Stride() == input.NumCols() && - params.NumRows() == num_channels_out_ && - params.NumCols() == num_channels_in_ * filter_height_ * filter_width_ && + params.NumRows() == c.num_channels_out && + params.NumCols() == c.num_channels_in * c.filter_height * c.filter_width && params.Stride() == params.NumCols() && - bias.Dim() == num_channels_out_ && - output->NumRows() == num_images_ * input_image_height_ && - output->NumCols() == input_image_width_ * num_channels_out_ && + bias.Dim() == c.num_channels_out && + output->NumRows() == c.num_images * c.input_image_height && + output->NumCols() == c.input_image_width * c.num_channels_out && output->Stride() == output->NumCols()); { // Deal with the bias. SubMatrix output_rearranged( output->Data(), - num_images_ * input_image_width_ * input_image_height_, - num_channels_out_, num_channels_out_); + c.num_images * c.input_image_width * c.input_image_height, + c.num_channels_out, c.num_channels_out); output_rearranged.CopyRowsFromVec(bias); } - Matrix params_rearranged(filter_width_ * filter_height_, - num_channels_out_ * num_channels_in_, + Matrix params_rearranged(c.filter_width * c.filter_height, + c.num_channels_out * c.num_channels_in, kUndefined, kStrideEqualNumCols); ConvertParams(params, ¶ms_rearranged); // We're using variable names w (as in width) for horizontal positions and h // (as in height) for vertical positions. This is perhaps not ideal. - for (int32 output_w = 0; output_w < output_image_width_; output_w++) { - for (int32 output_h = 0; output_h < output_image_height_; output_h++) { - for (int32 filter_h = 0; filter_h < filter_height_; filter_h++) { - int32 filter_h_flipped = filter_height_ - 1 - filter_h; - int32 input_h = output_h * filter_stride_vertical_ - - zero_padding_height_ - + filter_h * filter_dilation_height_; - if (input_h < 0 || input_h >= input_image_height_) + for (int32 output_w = 0; output_w < c.output_image_width; output_w++) { + for (int32 output_h = 0; output_h < c.output_image_height; output_h++) { + for (int32 filter_h = 0; filter_h < c.filter_height; filter_h++) { + int32 filter_h_flipped = c.filter_height - 1 - filter_h; + int32 input_h = output_h * c.filter_stride_vertical + - c.zero_padding_vertical + + filter_h * c.filter_dilation_vertical; + if (input_h < 0 || input_h >= c.input_image_height) continue; - for (int32 filter_w = 0; filter_w < filter_width_; filter_w++) { - int32 filter_w_flipped = filter_width_ - 1 - filter_w; - int32 input_w = output_w * filter_stride_horizontal_ - - zero_padding_width_ - + filter_w * filter_dilation_width_; + for (int32 filter_w = 0; filter_w < c.filter_width; filter_w++) { + int32 filter_w_flipped = c.filter_width - 1 - filter_w; + int32 input_w = output_w * c.filter_stride_horizontal + - c.zero_padding_horizontal + + filter_w * c.filter_dilation_horizontal; - if (input_w < 0 || input_w >= input_image_width_) + if (input_w < 0 || input_w >= c.input_image_width) continue; const BaseFloat *params_data = params_rearranged.RowData( - filter_w_flipped * filter_height_ + filter_h_flipped); + filter_w_flipped * c.filter_height + filter_h_flipped); SubMatrix this_params(params_data, - num_channels_out_, - num_channels_in_, num_channels_in_); + c.num_channels_out, + c.num_channels_in, c.num_channels_in); const BaseFloat *input_data = input.Data() + - input_w * input_image_height_ * num_channels_in_ + - input_h * num_channels_in_; + input_w * c.input_image_height * c.num_channels_in + + input_h * c.num_channels_in; SubMatrix this_input_pixel(input_data, - num_images_, - num_channels_in_, - num_channels_in_); + c.num_images, + c.num_channels_in, + c.num_channels_in); SubMatrix this_output_pixel(input_data, - num_images_, - num_channels_in_, - num_channels_in_); + c.num_images, + c.num_channels_in, + c.num_channels_in); this_output_pixel.AddMatMat(1.0, this_input_pixel, kNoTrans, this_params, kTrans, 1.0); } @@ -552,16 +579,17 @@ ConvolveBackwardBias(const MatrixBase &output_deriv, void ConvolutionComputation::ConvertParams( const MatrixBase ¶ms, MatrixBase *params_rearranged) const { - KALDI_ASSERT(params.NumRows() == num_channels_out_ && + const ConvolutionComputationConfig &c = config_; + KALDI_ASSERT(params.NumRows() == c.num_channels_out && params.Stride() == params.NumCols() && - params_rearranged->NumRows() == filter_width_ * filter_height_ && + params_rearranged->NumRows() == c.filter_width * c.filter_height && params_rearranged->Stride() == params_rearranged->NumCols()); // Reinterpret params as params_reinterpret which is of dimension KC * WH (instead of K * CWH). SubMatrix params_reinterpret(params.Data(), - num_channels_out_ * num_channels_in_, - filter_width_ * filter_height_, - filter_width_ * filter_height_); + c.num_channels_out * c.num_channels_in, + c.filter_width * c.filter_height, + c.filter_width * c.filter_height); params_rearranged->CopyFromMat(params_reinterpret, kTrans); } diff --git a/src/nnet3/convolution-cudnn.h b/src/nnet3/convolution-cudnn.h index 828150af035..efacab0bfa1 100644 --- a/src/nnet3/convolution-cudnn.h +++ b/src/nnet3/convolution-cudnn.h @@ -34,94 +34,122 @@ namespace kaldi { namespace nnet3 { namespace cudnn_convolution { +/** This struct contains information about a specific convolution + computation. It combines information about the model and the data. + The examples below are very arbitrary. + */ +struct ConvolutionComputationConfig { + // The number of images we are working on, e.g. 128. + int32 num_images; + // The number of input channels, e.g. 32 + int32 num_channels_in; + // The number of output channels, e.g. 64. + int32 num_channels_out; + // The number of pixels in the filter patch in the vertical direction, e.g. 3. + int32 filter_height; + // The number of pixels in the filter patch in the horizontal direction, e.g. 3. + int32 filter_width; + // The vertical stride of the filter, normally 1 but might be (e.g.) 2 if we + // are subsampling in the vertical (usually: frequency) dimension. + int32 filter_stride_vertical; + // The horizontal stride of the filter, normally 1 but might be (e.g.) 3 at a + // certain layer of the network if we are training a chain model with a + // frame-subsampling-factor of 3. + int32 filter_stride_horizontal; + // Normally 1, if this is more than 1 the pixels of the image patch will be + // spaced apart from each other. + int32 filter_dilation_vertical; + // Normally 1, if this is more than 1 the pixels of the image patch will be + // spaced apart from each other. + int32 filter_dilation_horizontal; + // The height of the input image, e.g. this is often 40 in speech applications, + // subsampled to 20 or 10 later in the network. + int32 input_image_height; + // The width of the input image, which will be the same as the number of + // frames being computed. + int32 input_image_width; + // The amount of zero-padding in the height (normally: frequency) dimension; + // this number of zero frames are added at both top and bottom of the input. + // Will often be 1, if you are using 3x3 kernels and don't want to + // reduce the height of the image. + int32 zero_padding_vertical; + // The amount of zero-padding in the time dimension, meaning the number of + // zero frames that we implicitly add to the beginning and end of the utterance. + // This will normally be zero because in Kaldi ASR recipes we generally don't + // do zero padding, but duplicate the first and last frame of the input to + // match the amount of left and right context that the neural network requires. + int32 zero_padding_horizontal; + + + // The height of the output image. The user does not have to set this; + // it will be computed when you call ComputeOutputImageSize(). + int32 output_image_height; + // The width of the output image. The user does not have to set this; + // it will be computed when you call ComputeOutputImageSize(). + int32 output_image_width; + + + // Checks that all the configuration variables except output_image_height + // and output_image_width have allowed values. + void Check(); + + // Computes output_image_height and output_image_width from the other + // configuration values. + void ComputeOutputImageSize(); + + void Write(std::ostream &os, bool binary) const; + + void Read(std::istream &is, bool binary); + +}; + + /** - Represents structural information about a convolution computation, with - filters, padding, striding, inputs and outputs of a specified size. The - same interface is usable on both GPU and CPU. You create this object only - after you know the number of images and input and output sizes, and it will - be stored as part of a NnetComputation (i.e. a compiled computation) and - re-used between different minibatches. This object is lightweight; it - doesn't contain data, only a few integers and descriptors. - - In the following docstrings: - N is equivalent to num_images - C is equivalent to num_channels_in - K is equivalent to num_channels_out - H is equivalent to input_image_height or output_image_height (for images) or - filter_height (for filter parameters). - W is equivalent to input_image_width or output_image_width (for images) or - filter_width (for filter parameters). - - @param [in] num_channels_out Number of output channels, e.g. 64. - @param [in] num_channels_in Number of input channels, e.g. 32. - @param [in] filter_height Height of filter patch, e.g. 3 (for 3x3 kernel). Corresponds - to the 'frequency' dimension in normal speech applications, or - height in OCR applications. - @param [in] filter_width Width of filter patch, e.g. 3 (for 3x3 kernel). Corresponds - to the 'time' dimension in normal speech applications. - @param [in] filter_stride_vertical Filter stride in the vertical ('frequency') dimension. - Will normally be 1 in speech and OCR applications. - @param [in] filter_stride_horizontal Filter stride in the horizontal ('time') dimension. - Will usually be 1 in most layers, but may be 2 or 3 if - we are doing subsampling on this layer (e.g. in - reduced-frame-rate models like chain models). - @param [in] filter_dilation_height Filter dilation in the vertical ('frequency') - dimension. Equals the stride, in the input image, of - individual elements of the filter patch. Will - normally be 1. - @param [in] filter_dilation_width Filter dilation in the horizontal ('time') - dimension. Will normally be 1, but could - be more than one if, for instance, you have components - with time-stride > 1 which for some reason are required - to be evaluated on every frame. - @param [in] num_images The number of images we are processing, generally - equal to the minibatch size. - @param [in] input_image_height The height of the input images. Corresponds to - the number of frequency bins, in speech applications. - @param [in] input_image_width The width of the input images. Corresponds to - the number of time frames on the input, in speech - applications. - @param [in] zero_padding_height The number of pixels that we zero-pad with on - the bottom, and on the top, of the image (the - frequency dimension, in speech applications). Would - be 1, for instance, if you are using a 3x3 kernel - and don't want to lose frequency bins. - @param [in] zero_padding_width The number of frames that we zero-pad with on - the left, and on the right, of the image (time - dimension). Likely to be 0 in many speech applications, - since we normally deal with edge effects by padding - with repeats of the first and last frame; but - padding is supported by this object. + This object allows you to execute a convolution computation, and its backprop, + on either a GPU (using CUDNN), or on CPU using a compatible interface. + + This object is quite lightweight: it only contains some structural data and a + few smallish CUDNN descriptors that are derived from it. + */ class ConvolutionComputation final { public: - ConvolutionComputation(int32 num_channels_out, int32 num_channels_in, - int32 filter_height, int32 filter_width, - int32 filter_stride_vertical, int32 filter_stride_horizontal, - int32 filter_dilation_height, - int32 filter_dilation_width, - int32 num_images, - int32 input_image_height, int32 input_image_width, - int32 zero_padding_height, int32 zero_padding_width); + // Note: you don't have to have done ComputeOutputImageSize() on 'config', + // this class will do it in the constructor. + ConvolutionComputation(const ConvolutionComputationConfig &config); + + // This constructor may be used prior to calling Read(). + ConvolutionComputation(); + ~ConvolutionComputation(); - int32 OutputImageHeight() const { return output_image_height_; } - int32 OutputImageWidth() const { return output_image_width_; } + /* + For an explanation of the notation below (e.g. NWHC): - /** - * For an explanation of the notation below (e.g. NWHC), see the - * explanation for those variable names in the documentation for this - * class above. Variables that come first have the higher stride. - * - * Caution: for convenience, given the way nnet3 works, we flip the notion of - * height and width that CUDNN uses, so our height is CUDNN's width, and vice - * versa. This is not visible to the user; we mention it just in case - * those familiar with CUDNN get surprised at the order - * - * @param [in] input NWHC fully-packed tensor, with NumRows() == N * W - * @param [in] params KCWH fully-packed tensor, with NumRows() == K. - * @param [in] bias vector of length K - * @param [out] output Pre-allocated NWHK fully-packed tensor, with N == NumRows() + N is equivalent to num_images + C is equivalent to num_channels_in + K is equivalent to num_channels_out + H is equivalent to input_image_height or output_image_height (for images) or + filter_height (for filter parameters). + W is equivalent to input_image_width or output_image_width (for images) or + filter_width (for filter parameters). + and the order of letters is from highest to lowest stride, e.g in + + NWHC, N would have the highest stride, and C a stride of 1. + + + explanation for those variable names in the documentation for this + class above. Variables that come first have the higher stride. + + Caution: for convenience, given the way nnet3 works, we flip the notion of + height and width that CUDNN uses, so our height is CUDNN's width, and vice + versa. This is not visible to the user; we mention it just in case + those familiar with CUDNN get surprised at the order + + @param [in] input NWHC fully-packed tensor, with NumRows() == N * W + @param [in] params KCWH fully-packed tensor, with NumRows() == K. + @param [in] bias vector of length K + @param [out] output Pre-allocated NWHK fully-packed tensor, with N == NumRows() */ void ConvolveForward(const CuMatrixBase &input, const CuMatrixBase ¶ms, @@ -184,9 +212,9 @@ class ConvolutionComputation final { - void Write(std::ostream &os, bool binary) const; + void Write(std::ostream &os, bool binary) const { config_.Write(os, binary); } - void Read(std::istream &os, bool binary); + void Read(std::istream &is, bool binary); private: #if HAVE_CUDA == 1 @@ -221,27 +249,11 @@ class ConvolutionComputation final { - // The following block of members are just copies of the args to the - // constructor. Please see the documentation of the constructor, and look for - // the similarly named parameter, to understand the meaning of these - // individual members. - int32 num_channels_out_; - int32 num_channels_in_; - int32 filter_height_; - int32 filter_width_; - int32 filter_stride_vertical_; - int32 filter_stride_horizontal_; - int32 filter_dilation_height_; - int32 filter_dilation_width_; - int32 num_images_; - int32 input_image_height_; - int32 input_image_width_; - int32 zero_padding_height_; - int32 zero_padding_width_; - int32 output_image_height_; - int32 output_image_width_; + ConvolutionComputationConfig config_; + #if HAVE_CUDA == 1 + bool descriptors_initialized_; cudnnTensorDescriptor_t input_desc_; cudnnTensorDescriptor_t output_desc_; cudnnFilterDescriptor_t params_desc_; diff --git a/tools/Makefile b/tools/Makefile index cab8245ee2e..1d62e1a3765 100644 --- a/tools/Makefile +++ b/tools/Makefile @@ -18,7 +18,7 @@ ifeq ("$(shell expr $(OPENFST_VER_NUM) \< 10600)","1") Supported versions: >= 1.6.0) endif -all: check_required_programs sph2pipe sclite openfst cudnn +all: check_required_programs sph2pipe sclite openfst @echo -e "\n\n" @echo "Warning: IRSTLM is not installed by default anymore. If you need IRSTLM" @echo "Warning: use the script extras/install_irstlm.sh" @@ -149,9 +149,3 @@ openblas_compiled: cd OpenBLAS; sed 's:# FCOMMON_OPT = -frecursive:FCOMMON_OPT = -frecursive:' < Makefile.rule >tmp && mv tmp Makefile.rule # $(MAKE) PREFIX=`pwd`/OpenBLAS/install FC=gfortran $(fortran_opt) DEBUG=1 USE_THREAD=1 NUM_THREADS=64 -C OpenBLAS all install $(MAKE) PREFIX=`pwd`/OpenBLAS/install FC=gfortran $(fortran_opt) DEBUG=1 USE_THREAD=0 -C OpenBLAS all install - -cudnn: - wget -T 10 -t 3 http://developer.download.nvidia.com/compute/redist/cudnn/v7.1.2/cudnn-9.1-linux-x64-v7.1.tgz -O cudnn-9.1-linux-x64-v7.1.tgz - -echo "c61000ed700bc5a009bc2e135bbdf736c9743212b2174a2fc9018a66cc0979ec cudnn-9.1-linux-x64-v7.1.tgz" | sha256sum -c - -mkdir -p cudnn/ - tar --no-same-owner -xzf cudnn-9.1-linux-x64-v7.1.tgz -C cudnn/ From b5d2022e9e3d4a6e362f48c199923b457399caf6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 23 Oct 2018 00:06:04 -0400 Subject: [PATCH 14/22] [src] Fix some bugs, stuck again. --- src/cudamatrix/cu-device.cc | 4 +- src/matrix/kaldi-matrix.cc | 3 +- src/nnet3/convolution-cudnn-test.cc | 27 ++++++---- src/nnet3/convolution-cudnn.cc | 81 ++++++++++++++++++++++------- src/nnet3/convolution-cudnn.h | 13 ++--- 5 files changed, 89 insertions(+), 39 deletions(-) diff --git a/src/cudamatrix/cu-device.cc b/src/cudamatrix/cu-device.cc index 0a9627c0518..d31cdeb82a5 100644 --- a/src/cudamatrix/cu-device.cc +++ b/src/cudamatrix/cu-device.cc @@ -115,7 +115,6 @@ void CuDevice::Initialize() { // Initialize the cuSPARSE library CUSPARSE_SAFE_CALL(cusparseCreate(&cusparse_handle_)); CUSPARSE_SAFE_CALL(cusparseSetStream(cusparse_handle_, cudaStreamPerThread)); - CUDNN_SAFE_CALL(cudnnCreate(&cudnn_handle_)); CUDNN_SAFE_CALL(cudnnSetStream(cudnn_handle_, cudaStreamPerThread)); } @@ -253,11 +252,10 @@ void CuDevice::FinalizeActiveGpu() { // Initialize the cuSPARSE library CUSPARSE_SAFE_CALL(cusparseCreate(&cusparse_handle_)); CUSPARSE_SAFE_CALL(cusparseSetStream(cusparse_handle_, cudaStreamPerThread)); - CUDNN_SAFE_CALL(cudnnCreate(&cudnn_handle_)); CUDNN_SAFE_CALL(cudnnSetStream(cudnn_handle_, cudaStreamPerThread)); - // Notify the user which GPU is being userd. + // Notify the user which GPU is being used. char name[128]; DeviceGetName(name,128, device_id); diff --git a/src/matrix/kaldi-matrix.cc b/src/matrix/kaldi-matrix.cc index ebdaca77e4f..df3aea3fe2c 100644 --- a/src/matrix/kaldi-matrix.cc +++ b/src/matrix/kaldi-matrix.cc @@ -1840,7 +1840,8 @@ void MatrixBase::Svd(VectorBase *s, MatrixBase *U, MatrixBase< KALDI_ERR << "Error doing Svd (did not converge), first part of matrix is\n" << SubMatrix(*this, 0, std::min((MatrixIndexT)10, num_rows_), 0, std::min((MatrixIndexT)10, num_cols_)) - << ", min and max are: " << Min() << ", " << Max(); + << ", min, max and sum are: " << Min() << ", " << Max() + << ", " << Sum(); } } diff --git a/src/nnet3/convolution-cudnn-test.cc b/src/nnet3/convolution-cudnn-test.cc index 559857abf7b..d00629fb6b1 100644 --- a/src/nnet3/convolution-cudnn-test.cc +++ b/src/nnet3/convolution-cudnn-test.cc @@ -33,25 +33,23 @@ static void GetRandomConfig(ConvolutionComputationConfig *config) { config->filter_height = RandInt(1, 3); config->filter_width = RandInt(1, 3); - // TODO: randomize the following as well. For now we just use - // default values. - config->filter_stride_vertical = 1; - config->filter_stride_horizontal = 1; - config->filter_dilation_vertical = 1; + config->filter_stride_vertical = RandInt(1, 2); + config->filter_stride_horizontal = RandInt(1, 2); + config->filter_dilation_vertical = RandInt(1, 2); config->filter_dilation_horizontal = 1; - config->input_image_height = RandInt(3, 10); - config->input_image_width = RandInt(3, 10); + config->input_image_height = RandInt(10, 20); + config->input_image_width = RandInt(10, 20); - config->zero_padding_vertical = 0; - config->zero_padding_horizontal = 0; + config->zero_padding_vertical = RandInt(0, 1); + config->zero_padding_horizontal = RandInt(0, 1); config->Check(); config->ComputeOutputImageSize(); } void TestConvolutionComputationConfig() { - for (int32 i = 0; i < 10; i++) { + for (int32 i = 0; i < 100; i++) { ConvolutionComputationConfig config; GetRandomConfig(&config); std::ostringstream os; @@ -68,9 +66,16 @@ void TestConvolutionComputationConfig() { } void TestConvolutionComputation() { - for (int32 i = 0; i < 10; i++) { + for (int32 i = 0; i < 100; i++) { ConvolutionComputationConfig config; GetRandomConfig(&config); + + { + std::ostringstream os; + config.Write(os, false); + KALDI_LOG << "Config is: " << os.str(); + } + ConvolutionComputation computation(config); std::ostringstream os; diff --git a/src/nnet3/convolution-cudnn.cc b/src/nnet3/convolution-cudnn.cc index 9f406d4868e..9442d85ebd1 100644 --- a/src/nnet3/convolution-cudnn.cc +++ b/src/nnet3/convolution-cudnn.cc @@ -126,6 +126,7 @@ void ConvolutionComputationConfig::Read(std::istream &is, bool binary) { ReadBasicType(is, binary, &zero_padding_vertical); ReadBasicType(is, binary, &zero_padding_horizontal); ExpectToken(is, binary, ""); + ComputeOutputImageSize(); } @@ -148,6 +149,13 @@ ConvolutionComputation::ConvolutionComputation() { #if HAVE_CUDA == 1 + +#if (KALDI_DOUBLEPRECISION != 0) +#define CUDNN_DATA_BASEFLOAT CUDNN_DATA_DOUBLE +#else +#define CUDNN_DATA_BASEFLOAT CUDNN_DATA_FLOAT +#endif + void ConvolutionComputation::InitCudnn() { descriptors_initialized_ = true; @@ -160,11 +168,16 @@ void ConvolutionComputation::InitCudnn() { CUDNN_SAFE_CALL(cudnnCreateConvolutionDescriptor(&conv_desc_)); CUDNN_SAFE_CALL(cudnnCreateActivationDescriptor(&activation_desc_)); + // Caution: in the following call, the 'height' and 'width' are swapped + // relative to what the CUDNN interface specifies; this is because Kaldi's + // notion of what is height vs. width is opposite to CUDNN's. (There + // are good reasons for this). CUDNN_SAFE_CALL( cudnnSetTensor4dDescriptor(input_desc_, CUDNN_TENSOR_NHWC, - CUDNN_DATA_FLOAT, c.num_images, + CUDNN_DATA_BASEFLOAT, c.num_images, c.num_channels_in, c.input_image_width, c.input_image_height)); + // Again: width and height are swapped. CUDNN_SAFE_CALL( cudnnSetConvolution2dDescriptor( conv_desc_, @@ -172,14 +185,24 @@ void ConvolutionComputation::InitCudnn() { c.filter_stride_horizontal, c.filter_stride_vertical, c.filter_dilation_horizontal, c.filter_dilation_vertical, CUDNN_CROSS_CORRELATION, // TODO: Double check this! - CUDNN_DATA_FLOAT)); + CUDNN_DATA_BASEFLOAT)); + + // Set dimensions of the filters (linear parameters). + // Again: width and height are swapped. Per the CUDNN documentation at + // https://docs.nvidia.com/deeplearning/sdk/pdf/cuDNN-Developer-Guide.pdf for + // cudnnSetFilter4dDescriptor, setting CUDNN_TENSOR_NHWC as the layout + // corresponds to KSRC, meaning: num-channels-out, height, width, num-channels-in, + // where 'height' and 'width' are the filter height and width respectively (e.g. 3 + // and 3 for a 3x3 patch); and these are swapped w.r.t. Kaldi's notion of height and + // width, so as far as Kaldi is concerned, the strides are, from largest to + // smallest: num-channels-out, width, height, num-channels-in. CUDNN_SAFE_CALL( - cudnnSetFilter4dDescriptor(params_desc_, CUDNN_DATA_FLOAT, - CUDNN_TENSOR_NCHW, c.num_channels_out, + cudnnSetFilter4dDescriptor(params_desc_, CUDNN_DATA_BASEFLOAT, + CUDNN_TENSOR_NHWC, c.num_channels_out, c.num_channels_in, c.filter_width, c.filter_height)); - int32 kaldi_height_cudnn_width, kaldi_width_cudnn_height, unused; + int32 kaldi_width_cudnn_height, kaldi_height_cudnn_width, unused; CUDNN_SAFE_CALL( cudnnGetConvolution2dForwardOutputDim(conv_desc_, input_desc_, params_desc_, @@ -194,18 +217,21 @@ void ConvolutionComputation::InitCudnn() { KALDI_ERR << "Code error: the width from CUDNN " << kaldi_width_cudnn_height << " does not match our value " << c.output_image_width; - // These two member functions depend only on input_desc_, // conv_desc_, and params_desc_, so they are safe to call now. CUDNN_SAFE_CALL( cudnnSetTensor4dDescriptor(output_desc_, CUDNN_TENSOR_NHWC, - CUDNN_DATA_FLOAT, c.num_images, - c.num_channels_in, kaldi_width_cudnn_height, + CUDNN_DATA_BASEFLOAT, c.num_images, + c.num_channels_out, kaldi_width_cudnn_height, kaldi_height_cudnn_width)); - const int32 bias_stride[] = {1}; + + // We pad the bias with leading dims of 1, since CUDNN's tensors appear to + // need a dimension of at least 3. + int bias_dims[3] = {1, 1, c.num_channels_out}; + int bias_stride[3] = {c.num_channels_out, c.num_channels_out, 1}; CUDNN_SAFE_CALL( - cudnnSetTensorNdDescriptor(bias_desc_, CUDNN_DATA_FLOAT, 1, - &c.num_channels_out, bias_stride)); + cudnnSetTensorNdDescriptor(bias_desc_, CUDNN_DATA_BASEFLOAT, 3, + bias_dims, bias_stride)); const double DONT_CARE = 0; CUDNN_SAFE_CALL( @@ -231,6 +257,7 @@ void ConvolutionComputation::InitCudnn() { KALDI_ASSERT(returned_algo_count > 0 && "No algorithms were returned by CUDNN."); const cudnnConvolutionFwdAlgoPerf_t& best_forward = forward_results[0]; + KALDI_ASSERT(best_forward.status == CUDNN_STATUS_SUCCESS); fwd_algo_ = best_forward.algo; delete [] forward_results; @@ -251,6 +278,7 @@ void ConvolutionComputation::InitCudnn() { "No algorithms were returned by CUDNN."); const cudnnConvolutionBwdFilterAlgoPerf_t& best_backward_filter = backward_filter_results[0]; + KALDI_ASSERT(best_backward_filter.status == CUDNN_STATUS_SUCCESS); bwd_filter_algo_ = best_backward_filter.algo; delete [] backward_filter_results; @@ -271,8 +299,10 @@ void ConvolutionComputation::InitCudnn() { "No algorithms were returned by CUDNN."); const cudnnConvolutionBwdDataAlgoPerf_t& best_backward_data = backward_data_results[0]; + KALDI_ASSERT(best_backward_data.status == CUDNN_STATUS_SUCCESS); bwd_data_algo_ = best_backward_data.algo; delete [] backward_data_results; + ComputeTempSpaceSizes(); } #endif @@ -574,7 +604,7 @@ ConvolveBackwardBias(const MatrixBase &output_deriv, // This function, called only if we are not using the GPU, converts -// the params from KCWH format to WHKC format (which is more convenient +// the params from KWHC format to WHKC format (which is more convenient // when using the CPU. Note: K == channels-out, C == channels-in. void ConvolutionComputation::ConvertParams( const MatrixBase ¶ms, @@ -585,12 +615,27 @@ void ConvolutionComputation::ConvertParams( params_rearranged->NumRows() == c.filter_width * c.filter_height && params_rearranged->Stride() == params_rearranged->NumCols()); - // Reinterpret params as params_reinterpret which is of dimension KC * WH (instead of K * CWH). - SubMatrix params_reinterpret(params.Data(), - c.num_channels_out * c.num_channels_in, - c.filter_width * c.filter_height, - c.filter_width * c.filter_height); - params_rearranged->CopyFromMat(params_reinterpret, kTrans); + int32 num_rows_reinterpret = + c.num_channels_out * c.filter_width * c.filter_height, + num_cols_reinterpret = c.num_channels_in, + area = c.filter_width * c.filter_height; + + // Reinterpret params as params_reinterpret which is of dimension KWH * C + SubMatrix params_reinterpret( + params.Data(), + num_rows_reinterpret, num_cols_reinterpret, num_cols_reinterpret); + SubMatrix params_rearranged_reinterpret( + params_rearranged->Data(), + num_rows_reinterpret, num_cols_reinterpret, num_cols_reinterpret); + for (int32 k = 0; k < c.num_channels_out; k++) { + for (int32 wh = 0; wh < area; wh++) { + int32 params_row = k * area + wh, + params_rearranged_row = wh * c.num_channels_out + k; + SubVector src(params_reinterpret, params_row), + dest(*params_rearranged, params_rearranged_row); + dest.CopyFromVec(src); + } + } } diff --git a/src/nnet3/convolution-cudnn.h b/src/nnet3/convolution-cudnn.h index efacab0bfa1..2e875108e79 100644 --- a/src/nnet3/convolution-cudnn.h +++ b/src/nnet3/convolution-cudnn.h @@ -99,6 +99,7 @@ struct ConvolutionComputationConfig { void Write(std::ostream &os, bool binary) const; + // Note: Read() automatically calls ComputeOutputImageSize(). void Read(std::istream &is, bool binary); }; @@ -147,7 +148,7 @@ class ConvolutionComputation final { those familiar with CUDNN get surprised at the order @param [in] input NWHC fully-packed tensor, with NumRows() == N * W - @param [in] params KCWH fully-packed tensor, with NumRows() == K. + @param [in] params KWHC fully-packed tensor, with NumRows() == K. @param [in] bias vector of length K @param [out] output Pre-allocated NWHK fully-packed tensor, with N == NumRows() */ @@ -157,7 +158,7 @@ class ConvolutionComputation final { CuMatrixBase *output) const; /** - * @param [in] params KCWH fully-packed tensor, with NumRows() == K + * @param [in] params KWHC fully-packed tensor, with NumRows() == K * @param [in] output_deriv NWHK fully-packed tensor, with NumRows() == N * W * @param [out] input_deriv Pre-allocated NWHC fully-packed tensor, with * NumRows() == N * W @@ -171,8 +172,8 @@ class ConvolutionComputation final { * @param [in] input NWHC fully-packed tensor, with NumRows() == N * W. * @param [in] alpha * params_deriv := alpha * gradient_computed + params_deriv - * @param [in] params KCWH fully-packed tensor, with NumRows() == K - * @param [out] params_deriv Pre-allocated KCWH fully-packed tensor, + * @param [in] params KWHC fully-packed tensor, with NumRows() == K + * @param [out] params_deriv Pre-allocated KWHC fully-packed tensor, * with NumRows() == K. */ void ConvolveBackwardParams(const CuMatrixBase &output_deriv, @@ -235,8 +236,8 @@ class ConvolutionComputation final { // This function, called only if we are not using the GPU, converts - // the params from KCWH format to WHKC format (which is more convenient - // when using the CPU. params and params_rearranged must both be + // the params from KWHC format to WHKC format (which is more convenient + // when using the CPU). params and params_rearranged must both be // packed (Stride() == NumCols()), params must have num-rows equal to K // (num_channels_out_), and params_rearranged must have num-rows equal // to to WH (filter_width_ * filter_height_). From 22669f6e8fad32bc39bb9deab445c804d1895ebc Mon Sep 17 00:00:00 2001 From: Daniel Galvez Date: Tue, 23 Oct 2018 03:20:38 -0400 Subject: [PATCH 15/22] Change filter type back to NCHW, since it supports more algos. Convert assertion to warning, although I am not sure this works, since if an algorithm fails to be found because of an out-of-memory error, it is likely to fail at training time for the same reason. --- src/nnet3/convolution-cudnn.cc | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/src/nnet3/convolution-cudnn.cc b/src/nnet3/convolution-cudnn.cc index 9442d85ebd1..d3476bdca2a 100644 --- a/src/nnet3/convolution-cudnn.cc +++ b/src/nnet3/convolution-cudnn.cc @@ -30,6 +30,15 @@ namespace { // static variables. const BaseFloat ONE(1.0); const BaseFloat ZERO(0.0); + +template +void CheckCorrectness(CudnnAlgoPerfT perf_results, const char* function) { + if (perf_results.status != CUDNN_STATUS_SUCCESS) { + KALDI_WARN << function << " had an error: " << + cudnnGetErrorString(perf_results.status) << ". Continuing with algo" << + perf_results.algo << " but results may not be ideal."; + } +} } @@ -198,7 +207,7 @@ void ConvolutionComputation::InitCudnn() { // smallest: num-channels-out, width, height, num-channels-in. CUDNN_SAFE_CALL( cudnnSetFilter4dDescriptor(params_desc_, CUDNN_DATA_BASEFLOAT, - CUDNN_TENSOR_NHWC, c.num_channels_out, + CUDNN_TENSOR_NCHW, c.num_channels_out, c.num_channels_in, c.filter_width, c.filter_height)); @@ -257,7 +266,7 @@ void ConvolutionComputation::InitCudnn() { KALDI_ASSERT(returned_algo_count > 0 && "No algorithms were returned by CUDNN."); const cudnnConvolutionFwdAlgoPerf_t& best_forward = forward_results[0]; - KALDI_ASSERT(best_forward.status == CUDNN_STATUS_SUCCESS); + CheckCorrectness(best_forward, "cudnnFindConvolutionForwardAlgorithm"); fwd_algo_ = best_forward.algo; delete [] forward_results; @@ -278,7 +287,8 @@ void ConvolutionComputation::InitCudnn() { "No algorithms were returned by CUDNN."); const cudnnConvolutionBwdFilterAlgoPerf_t& best_backward_filter = backward_filter_results[0]; - KALDI_ASSERT(best_backward_filter.status == CUDNN_STATUS_SUCCESS); + CheckCorrectness(best_backward_filter, + "cudnnFindConvolutionBackwardFilterAlgorithm"); bwd_filter_algo_ = best_backward_filter.algo; delete [] backward_filter_results; @@ -299,7 +309,8 @@ void ConvolutionComputation::InitCudnn() { "No algorithms were returned by CUDNN."); const cudnnConvolutionBwdDataAlgoPerf_t& best_backward_data = backward_data_results[0]; - KALDI_ASSERT(best_backward_data.status == CUDNN_STATUS_SUCCESS); + CheckCorrectness(best_backward_data, + "cudnnFindConvolutionBackwardDataAlgorithm"); bwd_data_algo_ = best_backward_data.algo; delete [] backward_data_results; From c958143aa6c876b4a1eefc5f0f561d540416908f Mon Sep 17 00:00:00 2001 From: Daniel Galvez Date: Tue, 23 Oct 2018 23:49:31 -0400 Subject: [PATCH 16/22] Workaround cudnnSetTensor4dDescriptor's striding bug. Use cudnnSetTensor4dDescriptor with strides we calculate ourselves instead. --- src/nnet3/convolution-cudnn-test.cc | 2 +- src/nnet3/convolution-cudnn.cc | 22 ++++++++++++++-------- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/nnet3/convolution-cudnn-test.cc b/src/nnet3/convolution-cudnn-test.cc index d00629fb6b1..a9c8b0b3939 100644 --- a/src/nnet3/convolution-cudnn-test.cc +++ b/src/nnet3/convolution-cudnn-test.cc @@ -36,7 +36,7 @@ static void GetRandomConfig(ConvolutionComputationConfig *config) { config->filter_stride_vertical = RandInt(1, 2); config->filter_stride_horizontal = RandInt(1, 2); config->filter_dilation_vertical = RandInt(1, 2); - config->filter_dilation_horizontal = 1; + config->filter_dilation_horizontal = RandInt(1, 2); config->input_image_height = RandInt(10, 20); config->input_image_width = RandInt(10, 20); diff --git a/src/nnet3/convolution-cudnn.cc b/src/nnet3/convolution-cudnn.cc index d3476bdca2a..157e8ef6d7d 100644 --- a/src/nnet3/convolution-cudnn.cc +++ b/src/nnet3/convolution-cudnn.cc @@ -181,11 +181,14 @@ void ConvolutionComputation::InitCudnn() { // relative to what the CUDNN interface specifies; this is because Kaldi's // notion of what is height vs. width is opposite to CUDNN's. (There // are good reasons for this). + int in_dims[4] = {c.num_images, c.num_channels_in, c.input_image_width, + c.input_image_height}; + int in_stride[4] = {c.num_channels_in * c.input_image_width * c.input_image_height, + c.input_image_width * c.input_image_height, + c.input_image_height, 1}; CUDNN_SAFE_CALL( - cudnnSetTensor4dDescriptor(input_desc_, CUDNN_TENSOR_NHWC, - CUDNN_DATA_BASEFLOAT, c.num_images, - c.num_channels_in, c.input_image_width, - c.input_image_height)); + cudnnSetTensorNdDescriptor(input_desc_, CUDNN_DATA_BASEFLOAT, 4, in_dims, + in_stride)); // Again: width and height are swapped. CUDNN_SAFE_CALL( cudnnSetConvolution2dDescriptor( @@ -228,11 +231,14 @@ void ConvolutionComputation::InitCudnn() { // These two member functions depend only on input_desc_, // conv_desc_, and params_desc_, so they are safe to call now. + int out_dims[4] = {c.num_images, c.num_channels_out, c.output_image_width, + c.output_image_height}; + int out_stride[4] = {c.num_channels_out * c.output_image_width * c.output_image_height, + c.output_image_width * c.output_image_height, + c.output_image_height, 1}; CUDNN_SAFE_CALL( - cudnnSetTensor4dDescriptor(output_desc_, CUDNN_TENSOR_NHWC, - CUDNN_DATA_BASEFLOAT, c.num_images, - c.num_channels_out, kaldi_width_cudnn_height, - kaldi_height_cudnn_width)); + cudnnSetTensorNdDescriptor(output_desc_, CUDNN_DATA_BASEFLOAT, 4, out_dims, + out_stride)); // We pad the bias with leading dims of 1, since CUDNN's tensors appear to // need a dimension of at least 3. From 74114d0f416fb1a524499900b3390b55156f5b58 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 25 Oct 2018 19:05:24 -0400 Subject: [PATCH 17/22] [src] Fix various bugs. --- src/nnet3/convolution-cudnn-test.cc | 103 +++++++++++++++++++++++++++ src/nnet3/convolution-cudnn.cc | 104 +++++++++++++--------------- src/nnet3/convolution-cudnn.h | 20 +++--- 3 files changed, 161 insertions(+), 66 deletions(-) diff --git a/src/nnet3/convolution-cudnn-test.cc b/src/nnet3/convolution-cudnn-test.cc index a9c8b0b3939..f1e0a9f18c0 100644 --- a/src/nnet3/convolution-cudnn-test.cc +++ b/src/nnet3/convolution-cudnn-test.cc @@ -65,6 +65,101 @@ void TestConvolutionComputationConfig() { } } + +void ConvolveForwardWithGpu( + const ConvolutionComputation &computation, + const MatrixBase &input, + const MatrixBase ¶ms, + const VectorBase &bias, + MatrixBase *output) { + CuMatrix input_gpu(input.NumRows(), input.NumCols(), + kUndefined, kStrideEqualNumCols); + input_gpu.CopyFromMat(input); + CuMatrix params_gpu(params.NumRows(), params.NumCols(), + kUndefined, kStrideEqualNumCols); + params_gpu.CopyFromMat(params); + CuVector bias_gpu(bias); + CuMatrix output_gpu(output->NumRows(), output->NumCols(), + kSetZero, kStrideEqualNumCols); + computation.ConvolveForward(input_gpu, params_gpu, bias_gpu, &output_gpu); + output->CopyFromMat(output_gpu); +} + + + +// Tests that the CPU version gives the expected results. The output +// has to be inspected by a human. +void TestConvolutionComputationForward( + const ConvolutionComputation &computation, + bool use_gpu) { + const ConvolutionComputationConfig &c = computation.Config(); + + Matrix input(c.num_images * c.input_image_width, + c.input_image_height * c.num_channels_in, + kSetZero, kStrideEqualNumCols), + output(c.num_images * c.output_image_width, + c.output_image_height * c.num_channels_out, + kSetZero, kStrideEqualNumCols), + params(c.num_channels_out, + c.num_channels_in * c.filter_width * c.filter_height, + kSetZero, kStrideEqualNumCols); + + // One parameter and one channel of input pixel will be nonzero-- for testing purposes. + + int32 n = RandInt(0, c.num_images - 1), + input_w = RandInt(0, c.input_image_width - 1), + input_h = RandInt(0, c.input_image_height - 1), + input_c = RandInt(0, c.num_channels_in - 1); + input(n * c.input_image_width + input_w, + input_h * c.num_channels_in + input_c) = 2.0; + + int32 output_c = RandInt(0, c.num_channels_out - 1), + filter_w = RandInt(0, c.filter_width - 1), + filter_h = RandInt(0, c.filter_height - 1); + + params(output_c, + input_c * c.filter_width * c.filter_height + + filter_w * c.filter_height + + filter_h) = 3.0; + + Vector bias(c.num_channels_out); + + if (use_gpu) { + ConvolveForwardWithGpu(computation, input, params, bias, &output); + } else { + computation.ConvolveForward(input, params, bias, + &output); + } + + KALDI_LOG << "Have nonzero input for n=" << n + << ", w=" << input_w << ", h=" << input_h + << ", input_channel=" << input_c; + KALDI_LOG << "Have nonzero filter for w=" + << filter_w << ", h=" << filter_h + << ", output_channel=" << output_c; + bool found_nonzero = false; + for (int32 n = 0; n < c.num_images; n++) { + for (int32 w = 0; w < c.output_image_width; w++) { + for (int32 h = 0; h < c.output_image_height; h++) { + for (int32 ch = 0; ch < c.num_channels_out; ch++) { + BaseFloat val = output(n * c.output_image_width + w, + h * c.num_channels_out + ch); + if (val != 0.0) { + found_nonzero = true; + KALDI_LOG << "Found nonzero value " << val << " for image n=" + << n << ", w=" << w << ", h=" << h + << ", output_channel=" << ch; + } + } + } + } + } + if (!found_nonzero) + KALDI_WARN << "Found no nonzero value, sum is " << output.Sum(); + + +} + void TestConvolutionComputation() { for (int32 i = 0; i < 100; i++) { ConvolutionComputationConfig config; @@ -88,6 +183,14 @@ void TestConvolutionComputation() { std::ostringstream os2; computation2.Write(os2, binary); KALDI_ASSERT(os.str() == os2.str()); + KALDI_LOG << "About to test without GPU."; + TestConvolutionComputationForward(computation2, false); +#if HAVE_CUDA == 1 + if (CuDevice::Instantiate().Enabled()) { + KALDI_LOG << "About to test with GPU"; + TestConvolutionComputationForward(computation2, true); + } +#endif } } diff --git a/src/nnet3/convolution-cudnn.cc b/src/nnet3/convolution-cudnn.cc index 157e8ef6d7d..a29c46ec04a 100644 --- a/src/nnet3/convolution-cudnn.cc +++ b/src/nnet3/convolution-cudnn.cc @@ -33,11 +33,9 @@ const BaseFloat ZERO(0.0); template void CheckCorrectness(CudnnAlgoPerfT perf_results, const char* function) { - if (perf_results.status != CUDNN_STATUS_SUCCESS) { - KALDI_WARN << function << " had an error: " << - cudnnGetErrorString(perf_results.status) << ". Continuing with algo" << - perf_results.algo << " but results may not be ideal."; - } + if (perf_results.status != CUDNN_STATUS_SUCCESS) + KALDI_ERR << function << " had an error: " << + cudnnGetErrorString(perf_results.status); } } @@ -181,6 +179,7 @@ void ConvolutionComputation::InitCudnn() { // relative to what the CUDNN interface specifies; this is because Kaldi's // notion of what is height vs. width is opposite to CUDNN's. (There // are good reasons for this). + // We use cudnnSetTensorNdDescriptor because of bugs in cudnnSetTensor4dDescriptor. int in_dims[4] = {c.num_images, c.num_channels_in, c.input_image_width, c.input_image_height}; int in_stride[4] = {c.num_channels_in * c.input_image_width * c.input_image_height, @@ -202,12 +201,13 @@ void ConvolutionComputation::InitCudnn() { // Set dimensions of the filters (linear parameters). // Again: width and height are swapped. Per the CUDNN documentation at // https://docs.nvidia.com/deeplearning/sdk/pdf/cuDNN-Developer-Guide.pdf for - // cudnnSetFilter4dDescriptor, setting CUDNN_TENSOR_NHWC as the layout - // corresponds to KSRC, meaning: num-channels-out, height, width, num-channels-in, + // cudnnSetFilter4dDescriptor, setting CUDNN_TENSOR_NCHW as the layout + // corresponds to KCRS, meaning: num-channels-out, num-channels-in, height, width, // where 'height' and 'width' are the filter height and width respectively (e.g. 3 // and 3 for a 3x3 patch); and these are swapped w.r.t. Kaldi's notion of height and // width, so as far as Kaldi is concerned, the strides are, from largest to - // smallest: num-channels-out, width, height, num-channels-in. + // smallest: num-channels-out, width, height, num-channels-in: so as far + // as Kaldi is concerned the layout is KCWH (== KCSR, in their notation). CUDNN_SAFE_CALL( cudnnSetFilter4dDescriptor(params_desc_, CUDNN_DATA_BASEFLOAT, CUDNN_TENSOR_NCHW, c.num_channels_out, @@ -397,13 +397,14 @@ ConvolveForward(const CuMatrixBase &input, input.Stride() == input.NumCols() && params.NumRows() == c.num_channels_out && params.NumCols() == c.num_channels_in * c.filter_height * c.filter_width && - params.Stride() == params.NumCols() && + params.Stride() == params.NumCols()); + KALDI_ASSERT( bias.Dim() == c.num_channels_out && - output->NumRows() == c.num_images * c.input_image_height && - output->NumCols() == c.input_image_width * c.num_channels_out && + output->NumRows() == c.num_images * c.output_image_width && + output->NumCols() == c.output_image_height * c.num_channels_out && output->Stride() == output->NumCols()); -#ifdef HAVE_CUDNN +#if HAVE_CUDA == 1 if (CuDevice::Instantiate().Enabled()) { CuVector temp_space(temp_space_required_forward_ / sizeof(BaseFloat), kUndefined); @@ -448,17 +449,18 @@ ConvolveForward(const MatrixBase &input, input.Stride() == input.NumCols() && params.NumRows() == c.num_channels_out && params.NumCols() == c.num_channels_in * c.filter_height * c.filter_width && - params.Stride() == params.NumCols() && + params.Stride() == params.NumCols()); + KALDI_ASSERT( bias.Dim() == c.num_channels_out && - output->NumRows() == c.num_images * c.input_image_height && - output->NumCols() == c.input_image_width * c.num_channels_out && + output->NumRows() == c.num_images * c.output_image_width && + output->NumCols() == c.output_image_height * c.num_channels_out && output->Stride() == output->NumCols()); { // Deal with the bias. SubMatrix output_rearranged( output->Data(), - c.num_images * c.input_image_width * c.input_image_height, + c.num_images * c.output_image_width * c.output_image_height, c.num_channels_out, c.num_channels_out); output_rearranged.CopyRowsFromVec(bias); } @@ -468,19 +470,28 @@ ConvolveForward(const MatrixBase &input, kUndefined, kStrideEqualNumCols); ConvertParams(params, ¶ms_rearranged); + // The strides in 'input' and 'output' respectively from a certain pixel of one + // image to the same pixel in another image. + int32 input_image_stride = + c.input_image_width * c.num_channels_in * c.input_image_height, + output_image_stride = + c.output_image_width * c.num_channels_out * c.output_image_height; + // We're using variable names w (as in width) for horizontal positions and h // (as in height) for vertical positions. This is perhaps not ideal. for (int32 output_w = 0; output_w < c.output_image_width; output_w++) { for (int32 output_h = 0; output_h < c.output_image_height; output_h++) { for (int32 filter_h = 0; filter_h < c.filter_height; filter_h++) { - int32 filter_h_flipped = c.filter_height - 1 - filter_h; + //int32 filter_h_flipped = c.filter_height - 1 - filter_h; + int32 filter_h_flipped = filter_h; // we don't flip. int32 input_h = output_h * c.filter_stride_vertical - c.zero_padding_vertical + filter_h * c.filter_dilation_vertical; if (input_h < 0 || input_h >= c.input_image_height) continue; for (int32 filter_w = 0; filter_w < c.filter_width; filter_w++) { - int32 filter_w_flipped = c.filter_width - 1 - filter_w; + // int32 filter_w_flipped = c.filter_width - 1 - filter_w; + int32 filter_w_flipped = filter_w; // we don't flip. int32 input_w = output_w * c.filter_stride_horizontal - c.zero_padding_horizontal + filter_w * c.filter_dilation_horizontal; @@ -493,17 +504,20 @@ ConvolveForward(const MatrixBase &input, SubMatrix this_params(params_data, c.num_channels_out, c.num_channels_in, c.num_channels_in); - const BaseFloat *input_data = input.Data() + - input_w * c.input_image_height * c.num_channels_in + - input_h * c.num_channels_in; + const BaseFloat *input_data = + input.RowData(input_w) + input_h * c.num_channels_in; SubMatrix this_input_pixel(input_data, c.num_images, c.num_channels_in, - c.num_channels_in); - SubMatrix this_output_pixel(input_data, + input_image_stride); + + + const BaseFloat *output_data = + output->RowData(output_w) + output_h * c.num_channels_out; + SubMatrix this_output_pixel(output_data, c.num_images, - c.num_channels_in, - c.num_channels_in); + c.num_channels_out, + output_image_stride); this_output_pixel.AddMatMat(1.0, this_input_pixel, kNoTrans, this_params, kTrans, 1.0); } @@ -516,7 +530,7 @@ void ConvolutionComputation:: ConvolveBackwardData(const CuMatrixBase ¶ms, const CuMatrixBase &output_deriv, CuMatrixBase *input_deriv) const { -#ifdef HAVE_CUDNN +#if HAVE_CUDA == 1 if (CuDevice::Instantiate().Enabled()) { CuVector temp_space(temp_space_required_backward_data_ / sizeof(BaseFloat), kUndefined); @@ -555,9 +569,9 @@ ConvolveBackwardParams(const CuMatrixBase &output_deriv, const CuMatrixBase &input, BaseFloat alpha, CuMatrixBase *params_deriv) const { -#ifdef HAVE_CUDNN +#if HAVE_CUDA == 1 if (CuDevice::Instantiate().Enabled()) { - CuVector temp_space(temp_space_required_backward_params_ / + CuVector temp_space(temp_space_required_backward_filter_ / sizeof(BaseFloat), kUndefined); CUDNN_SAFE_CALL(cudnnConvolutionBackwardFilter( CuDevice::Instantiate().GetCudnnHandle(), @@ -595,7 +609,7 @@ void ConvolutionComputation:: ConvolveBackwardBias(const CuMatrixBase &output_deriv, BaseFloat alpha, CuVectorBase *bias_deriv) const { -#ifdef HAVE_CUDNN +#if HAVE_CUDA == 1 if (CuDevice::Instantiate().Enabled()) { CUDNN_SAFE_CALL(cudnnConvolutionBackwardBias( CuDevice::Instantiate().GetCudnnHandle(), @@ -627,32 +641,12 @@ void ConvolutionComputation::ConvertParams( const MatrixBase ¶ms, MatrixBase *params_rearranged) const { const ConvolutionComputationConfig &c = config_; - KALDI_ASSERT(params.NumRows() == c.num_channels_out && - params.Stride() == params.NumCols() && - params_rearranged->NumRows() == c.filter_width * c.filter_height && - params_rearranged->Stride() == params_rearranged->NumCols()); - - int32 num_rows_reinterpret = - c.num_channels_out * c.filter_width * c.filter_height, - num_cols_reinterpret = c.num_channels_in, - area = c.filter_width * c.filter_height; - - // Reinterpret params as params_reinterpret which is of dimension KWH * C - SubMatrix params_reinterpret( - params.Data(), - num_rows_reinterpret, num_cols_reinterpret, num_cols_reinterpret); - SubMatrix params_rearranged_reinterpret( - params_rearranged->Data(), - num_rows_reinterpret, num_cols_reinterpret, num_cols_reinterpret); - for (int32 k = 0; k < c.num_channels_out; k++) { - for (int32 wh = 0; wh < area; wh++) { - int32 params_row = k * area + wh, - params_rearranged_row = wh * c.num_channels_out + k; - SubVector src(params_reinterpret, params_row), - dest(*params_rearranged, params_rearranged_row); - dest.CopyFromVec(src); - } - } + // Reinterpret params as params_reinterpret which is of dimension KC * WH (instead of K * CWH). + SubMatrix params_reinterpret(params.Data(), + c.num_channels_out * c.num_channels_in, + c.filter_width * c.filter_height, + c.filter_width * c.filter_height); + params_rearranged->CopyFromMat(params_reinterpret, kTrans); } diff --git a/src/nnet3/convolution-cudnn.h b/src/nnet3/convolution-cudnn.h index 2e875108e79..a01887b05b8 100644 --- a/src/nnet3/convolution-cudnn.h +++ b/src/nnet3/convolution-cudnn.h @@ -122,6 +122,8 @@ class ConvolutionComputation final { // This constructor may be used prior to calling Read(). ConvolutionComputation(); + const ConvolutionComputationConfig &Config() const { return config_; } + ~ConvolutionComputation(); /* @@ -136,11 +138,7 @@ class ConvolutionComputation final { filter_width (for filter parameters). and the order of letters is from highest to lowest stride, e.g in - NWHC, N would have the highest stride, and C a stride of 1. - - - explanation for those variable names in the documentation for this - class above. Variables that come first have the higher stride. + In NWHC, N would have the highest stride, and C a stride of 1. Caution: for convenience, given the way nnet3 works, we flip the notion of height and width that CUDNN uses, so our height is CUDNN's width, and vice @@ -148,9 +146,9 @@ class ConvolutionComputation final { those familiar with CUDNN get surprised at the order @param [in] input NWHC fully-packed tensor, with NumRows() == N * W - @param [in] params KWHC fully-packed tensor, with NumRows() == K. + @param [in] params KCWH fully-packed tensor, with NumRows() == K. @param [in] bias vector of length K - @param [out] output Pre-allocated NWHK fully-packed tensor, with N == NumRows() + @param [out] output Pre-allocated NWHK fully-packed tensor, with NumRows() == N * W. */ void ConvolveForward(const CuMatrixBase &input, const CuMatrixBase ¶ms, @@ -158,7 +156,7 @@ class ConvolutionComputation final { CuMatrixBase *output) const; /** - * @param [in] params KWHC fully-packed tensor, with NumRows() == K + * @param [in] params KCWH fully-packed tensor, with NumRows() == K * @param [in] output_deriv NWHK fully-packed tensor, with NumRows() == N * W * @param [out] input_deriv Pre-allocated NWHC fully-packed tensor, with * NumRows() == N * W @@ -172,8 +170,8 @@ class ConvolutionComputation final { * @param [in] input NWHC fully-packed tensor, with NumRows() == N * W. * @param [in] alpha * params_deriv := alpha * gradient_computed + params_deriv - * @param [in] params KWHC fully-packed tensor, with NumRows() == K - * @param [out] params_deriv Pre-allocated KWHC fully-packed tensor, + * @param [in] params KCWH fully-packed tensor, with NumRows() == K + * @param [out] params_deriv Pre-allocated KCWH fully-packed tensor, * with NumRows() == K. */ void ConvolveBackwardParams(const CuMatrixBase &output_deriv, @@ -236,7 +234,7 @@ class ConvolutionComputation final { // This function, called only if we are not using the GPU, converts - // the params from KWHC format to WHKC format (which is more convenient + // the params from KCWH format to WHKC format (which is more convenient // when using the CPU). params and params_rearranged must both be // packed (Stride() == NumCols()), params must have num-rows equal to K // (num_channels_out_), and params_rearranged must have num-rows equal From 0bed8aa8e9f3887eac6f311eb13643c7c979ec61 Mon Sep 17 00:00:00 2001 From: Daniel Galvez Date: Fri, 26 Oct 2018 12:03:05 -0400 Subject: [PATCH 18/22] Don't use cudnnConvolutionBiasActivationForward. It supports only the precomputed implicit GEMM implementation of convolution. --- src/nnet3/convolution-cudnn.cc | 70 +++++++++++++++------------------- src/nnet3/convolution-cudnn.h | 1 - 2 files changed, 31 insertions(+), 40 deletions(-) diff --git a/src/nnet3/convolution-cudnn.cc b/src/nnet3/convolution-cudnn.cc index a29c46ec04a..db20be38cf6 100644 --- a/src/nnet3/convolution-cudnn.cc +++ b/src/nnet3/convolution-cudnn.cc @@ -173,7 +173,6 @@ void ConvolutionComputation::InitCudnn() { CUDNN_SAFE_CALL(cudnnCreateFilterDescriptor(¶ms_desc_)); CUDNN_SAFE_CALL(cudnnCreateTensorDescriptor(&bias_desc_)); CUDNN_SAFE_CALL(cudnnCreateConvolutionDescriptor(&conv_desc_)); - CUDNN_SAFE_CALL(cudnnCreateActivationDescriptor(&activation_desc_)); // Caution: in the following call, the 'height' and 'width' are swapped // relative to what the CUDNN interface specifies; this is because Kaldi's @@ -181,13 +180,13 @@ void ConvolutionComputation::InitCudnn() { // are good reasons for this). // We use cudnnSetTensorNdDescriptor because of bugs in cudnnSetTensor4dDescriptor. int in_dims[4] = {c.num_images, c.num_channels_in, c.input_image_width, - c.input_image_height}; + c.input_image_height}; int in_stride[4] = {c.num_channels_in * c.input_image_width * c.input_image_height, - c.input_image_width * c.input_image_height, - c.input_image_height, 1}; + c.input_image_width * c.input_image_height, + c.input_image_height, 1}; CUDNN_SAFE_CALL( cudnnSetTensorNdDescriptor(input_desc_, CUDNN_DATA_BASEFLOAT, 4, in_dims, - in_stride)); + in_stride)); // Again: width and height are swapped. CUDNN_SAFE_CALL( cudnnSetConvolution2dDescriptor( @@ -232,35 +231,30 @@ void ConvolutionComputation::InitCudnn() { // These two member functions depend only on input_desc_, // conv_desc_, and params_desc_, so they are safe to call now. int out_dims[4] = {c.num_images, c.num_channels_out, c.output_image_width, - c.output_image_height}; + c.output_image_height}; int out_stride[4] = {c.num_channels_out * c.output_image_width * c.output_image_height, - c.output_image_width * c.output_image_height, - c.output_image_height, 1}; + c.output_image_width * c.output_image_height, + c.output_image_height, 1}; CUDNN_SAFE_CALL( cudnnSetTensorNdDescriptor(output_desc_, CUDNN_DATA_BASEFLOAT, 4, out_dims, - out_stride)); + out_stride)); // We pad the bias with leading dims of 1, since CUDNN's tensors appear to // need a dimension of at least 3. - int bias_dims[3] = {1, 1, c.num_channels_out}; - int bias_stride[3] = {c.num_channels_out, c.num_channels_out, 1}; + int bias_dims[4] = {1, c.num_channels_out, 1, 1}; + int bias_stride[4] = {c.num_channels_out, 1, 1, 1}; CUDNN_SAFE_CALL( - cudnnSetTensorNdDescriptor(bias_desc_, CUDNN_DATA_BASEFLOAT, 3, + cudnnSetTensorNdDescriptor(bias_desc_, CUDNN_DATA_BASEFLOAT, 4, bias_dims, bias_stride)); - const double DONT_CARE = 0; - CUDNN_SAFE_CALL( - cudnnSetActivationDescriptor(activation_desc_, CUDNN_ACTIVATION_IDENTITY, - CUDNN_PROPAGATE_NAN, DONT_CARE)); - int32 requested_algo_count, returned_algo_count; CUDNN_SAFE_CALL(cudnnGetConvolutionForwardAlgorithmMaxCount( - CuDevice::Instantiate().GetCudnnHandle(), &requested_algo_count)); + GetCudnnHandle(), &requested_algo_count)); cudnnConvolutionFwdAlgoPerf_t *forward_results = new cudnnConvolutionFwdAlgoPerf_t[requested_algo_count]; CUDNN_SAFE_CALL(cudnnFindConvolutionForwardAlgorithm( - CuDevice::Instantiate().GetCudnnHandle(), + GetCudnnHandle(), input_desc_, params_desc_, conv_desc_, @@ -277,11 +271,11 @@ void ConvolutionComputation::InitCudnn() { delete [] forward_results; CUDNN_SAFE_CALL(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount( - CuDevice::Instantiate().GetCudnnHandle(), &requested_algo_count)); + GetCudnnHandle(), &requested_algo_count)); cudnnConvolutionBwdFilterAlgoPerf_t *backward_filter_results = new cudnnConvolutionBwdFilterAlgoPerf_t[requested_algo_count]; CUDNN_SAFE_CALL(cudnnFindConvolutionBackwardFilterAlgorithm( - CuDevice::Instantiate().GetCudnnHandle(), + GetCudnnHandle(), input_desc_, output_desc_, conv_desc_, @@ -294,16 +288,16 @@ void ConvolutionComputation::InitCudnn() { const cudnnConvolutionBwdFilterAlgoPerf_t& best_backward_filter = backward_filter_results[0]; CheckCorrectness(best_backward_filter, - "cudnnFindConvolutionBackwardFilterAlgorithm"); + "cudnnFindConvolutionBackwardFilterAlgorithm"); bwd_filter_algo_ = best_backward_filter.algo; delete [] backward_filter_results; CUDNN_SAFE_CALL(cudnnGetConvolutionBackwardDataAlgorithmMaxCount( - CuDevice::Instantiate().GetCudnnHandle(), &requested_algo_count)); + GetCudnnHandle(), &requested_algo_count)); cudnnConvolutionBwdDataAlgoPerf_t *backward_data_results = new cudnnConvolutionBwdDataAlgoPerf_t[requested_algo_count]; CUDNN_SAFE_CALL(cudnnFindConvolutionBackwardDataAlgorithm( - CuDevice::Instantiate().GetCudnnHandle(), + GetCudnnHandle(), params_desc_, output_desc_, conv_desc_, @@ -316,7 +310,7 @@ void ConvolutionComputation::InitCudnn() { const cudnnConvolutionBwdDataAlgoPerf_t& best_backward_data = backward_data_results[0]; CheckCorrectness(best_backward_data, - "cudnnFindConvolutionBackwardDataAlgorithm"); + "cudnnFindConvolutionBackwardDataAlgorithm"); bwd_data_algo_ = best_backward_data.algo; delete [] backward_data_results; @@ -327,7 +321,7 @@ void ConvolutionComputation::InitCudnn() { #if HAVE_CUDA == 1 void ConvolutionComputation::ComputeTempSpaceSizes() { CUDNN_SAFE_CALL(cudnnGetConvolutionForwardWorkspaceSize( - CuDevice::Instantiate().GetCudnnHandle(), + GetCudnnHandle(), input_desc_, params_desc_, conv_desc_, @@ -336,7 +330,7 @@ void ConvolutionComputation::ComputeTempSpaceSizes() { &temp_space_required_forward_)); CUDNN_SAFE_CALL(cudnnGetConvolutionBackwardDataWorkspaceSize( - CuDevice::Instantiate().GetCudnnHandle(), + GetCudnnHandle(), params_desc_, output_desc_, conv_desc_, @@ -345,7 +339,7 @@ void ConvolutionComputation::ComputeTempSpaceSizes() { &temp_space_required_backward_data_)); CUDNN_SAFE_CALL(cudnnGetConvolutionBackwardFilterWorkspaceSize( - CuDevice::Instantiate().GetCudnnHandle(), + GetCudnnHandle(), input_desc_, output_desc_, conv_desc_, @@ -362,7 +356,6 @@ void ConvolutionComputation::DestroyCudnn() { CUDNN_SAFE_CALL(cudnnDestroyFilterDescriptor(params_desc_)); CUDNN_SAFE_CALL(cudnnDestroyTensorDescriptor(bias_desc_)); CUDNN_SAFE_CALL(cudnnDestroyConvolutionDescriptor(conv_desc_)); - CUDNN_SAFE_CALL(cudnnDestroyActivationDescriptor(activation_desc_)); } #endif @@ -408,8 +401,9 @@ ConvolveForward(const CuMatrixBase &input, if (CuDevice::Instantiate().Enabled()) { CuVector temp_space(temp_space_required_forward_ / sizeof(BaseFloat), kUndefined); - CUDNN_SAFE_CALL(cudnnConvolutionBiasActivationForward( - CuDevice::Instantiate().GetCudnnHandle(), + + CUDNN_SAFE_CALL(cudnnConvolutionForward( + GetCudnnHandle(), &ONE, input_desc_, input.Data(), @@ -421,12 +415,10 @@ ConvolveForward(const CuMatrixBase &input, temp_space.Dim() * sizeof(BaseFloat), &ZERO, output_desc_, - output->Data(), - bias_desc_, - bias.Data(), - activation_desc_, - output_desc_, output->Data())); + CUDNN_SAFE_CALL(cudnnAddTensor(GetCudnnHandle(), + &ONE, bias_desc_, bias.Data(), &ONE, + output_desc_, output->Data())); } else #endif { @@ -535,7 +527,7 @@ ConvolveBackwardData(const CuMatrixBase ¶ms, CuVector temp_space(temp_space_required_backward_data_ / sizeof(BaseFloat), kUndefined); CUDNN_SAFE_CALL(cudnnConvolutionBackwardData( - CuDevice::Instantiate().GetCudnnHandle(), + GetCudnnHandle(), &ONE, params_desc_, params.Data(), @@ -574,7 +566,7 @@ ConvolveBackwardParams(const CuMatrixBase &output_deriv, CuVector temp_space(temp_space_required_backward_filter_ / sizeof(BaseFloat), kUndefined); CUDNN_SAFE_CALL(cudnnConvolutionBackwardFilter( - CuDevice::Instantiate().GetCudnnHandle(), + GetCudnnHandle(), &alpha, input_desc_, input.Data(), @@ -612,7 +604,7 @@ ConvolveBackwardBias(const CuMatrixBase &output_deriv, #if HAVE_CUDA == 1 if (CuDevice::Instantiate().Enabled()) { CUDNN_SAFE_CALL(cudnnConvolutionBackwardBias( - CuDevice::Instantiate().GetCudnnHandle(), + GetCudnnHandle(), &alpha, output_desc_, output_deriv.Data(), diff --git a/src/nnet3/convolution-cudnn.h b/src/nnet3/convolution-cudnn.h index a01887b05b8..22dc8535a4c 100644 --- a/src/nnet3/convolution-cudnn.h +++ b/src/nnet3/convolution-cudnn.h @@ -258,7 +258,6 @@ class ConvolutionComputation final { cudnnFilterDescriptor_t params_desc_; cudnnTensorDescriptor_t bias_desc_; cudnnConvolutionDescriptor_t conv_desc_; - cudnnActivationDescriptor_t activation_desc_; cudnnConvolutionFwdAlgo_t fwd_algo_; cudnnConvolutionBwdFilterAlgo_t bwd_filter_algo_; From c546716de851057d67ffd77280fb9214c41037d9 Mon Sep 17 00:00:00 2001 From: Daniel Galvez Date: Fri, 26 Oct 2018 18:00:48 -0400 Subject: [PATCH 19/22] Explain bias dimensions. --- src/nnet3/convolution-cudnn.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/nnet3/convolution-cudnn.cc b/src/nnet3/convolution-cudnn.cc index db20be38cf6..c1b300f796e 100644 --- a/src/nnet3/convolution-cudnn.cc +++ b/src/nnet3/convolution-cudnn.cc @@ -239,8 +239,10 @@ void ConvolutionComputation::InitCudnn() { cudnnSetTensorNdDescriptor(output_desc_, CUDNN_DATA_BASEFLOAT, 4, out_dims, out_stride)); - // We pad the bias with leading dims of 1, since CUDNN's tensors appear to - // need a dimension of at least 3. + // Since the output tensor shape is NKHW, we need the bias to be + // four-dimensional and the length of each dimension of the bias + // equal to either one or the output tensor's corresponding + // length. Singleton dimensions are broadcasted. int bias_dims[4] = {1, c.num_channels_out, 1, 1}; int bias_stride[4] = {c.num_channels_out, 1, 1, 1}; CUDNN_SAFE_CALL( From 450f491c786485033bd48ce040afbe73f58e5192 Mon Sep 17 00:00:00 2001 From: Daniel Galvez Date: Sat, 27 Oct 2018 00:25:01 -0400 Subject: [PATCH 20/22] Make bias optional. The ConvolutionComputation class does not know whether or not the bias is optional at initialization time. It will simply avoid using the bias if it is a nullptr. --- src/nnet3/convolution-cudnn-test.cc | 4 ++-- src/nnet3/convolution-cudnn.cc | 28 ++++++++++++++++++---------- src/nnet3/convolution-cudnn.h | 4 ++-- 3 files changed, 22 insertions(+), 14 deletions(-) diff --git a/src/nnet3/convolution-cudnn-test.cc b/src/nnet3/convolution-cudnn-test.cc index f1e0a9f18c0..b51cb04ce82 100644 --- a/src/nnet3/convolution-cudnn-test.cc +++ b/src/nnet3/convolution-cudnn-test.cc @@ -81,7 +81,7 @@ void ConvolveForwardWithGpu( CuVector bias_gpu(bias); CuMatrix output_gpu(output->NumRows(), output->NumCols(), kSetZero, kStrideEqualNumCols); - computation.ConvolveForward(input_gpu, params_gpu, bias_gpu, &output_gpu); + computation.ConvolveForward(input_gpu, params_gpu, &bias_gpu, &output_gpu); output->CopyFromMat(output_gpu); } @@ -127,7 +127,7 @@ void TestConvolutionComputationForward( if (use_gpu) { ConvolveForwardWithGpu(computation, input, params, bias, &output); } else { - computation.ConvolveForward(input, params, bias, + computation.ConvolveForward(input, params, &bias, &output); } diff --git a/src/nnet3/convolution-cudnn.cc b/src/nnet3/convolution-cudnn.cc index c1b300f796e..d291387c350 100644 --- a/src/nnet3/convolution-cudnn.cc +++ b/src/nnet3/convolution-cudnn.cc @@ -382,7 +382,7 @@ void ConvolutionComputation::Read(std::istream &is, bool binary) { void ConvolutionComputation:: ConvolveForward(const CuMatrixBase &input, const CuMatrixBase ¶ms, - const CuVectorBase &bias, + const CuVectorBase *bias, CuMatrixBase *output) const { const ConvolutionComputationConfig &c = config_; // Check some dimensions. @@ -394,7 +394,7 @@ ConvolveForward(const CuMatrixBase &input, params.NumCols() == c.num_channels_in * c.filter_height * c.filter_width && params.Stride() == params.NumCols()); KALDI_ASSERT( - bias.Dim() == c.num_channels_out && + (bias == nullptr || bias->Dim() == c.num_channels_out) && output->NumRows() == c.num_images * c.output_image_width && output->NumCols() == c.output_image_height * c.num_channels_out && output->Stride() == output->NumCols()); @@ -418,13 +418,15 @@ ConvolveForward(const CuMatrixBase &input, &ZERO, output_desc_, output->Data())); - CUDNN_SAFE_CALL(cudnnAddTensor(GetCudnnHandle(), - &ONE, bias_desc_, bias.Data(), &ONE, - output_desc_, output->Data())); + if (bias != nullptr) { + CUDNN_SAFE_CALL(cudnnAddTensor(GetCudnnHandle(), + &ONE, bias_desc_, bias->Data(), &ONE, + output_desc_, output->Data())); + } } else #endif { - ConvolveForward(input.Mat(), params.Mat(), bias.Vec(), + ConvolveForward(input.Mat(), params.Mat(), &(bias->Vec()), &(output->Mat())); } } @@ -433,7 +435,7 @@ ConvolveForward(const CuMatrixBase &input, void ConvolutionComputation:: ConvolveForward(const MatrixBase &input, const MatrixBase ¶ms, - const VectorBase &bias, + const VectorBase *bias, MatrixBase *output) const { const ConvolutionComputationConfig &c = config_; // Check some dimensions. @@ -445,18 +447,18 @@ ConvolveForward(const MatrixBase &input, params.NumCols() == c.num_channels_in * c.filter_height * c.filter_width && params.Stride() == params.NumCols()); KALDI_ASSERT( - bias.Dim() == c.num_channels_out && + (bias != nullptr || bias->Dim() == c.num_channels_out) && output->NumRows() == c.num_images * c.output_image_width && output->NumCols() == c.output_image_height * c.num_channels_out && output->Stride() == output->NumCols()); - { // Deal with the bias. + if (bias != nullptr) { // Deal with the bias. SubMatrix output_rearranged( output->Data(), c.num_images * c.output_image_width * c.output_image_height, c.num_channels_out, c.num_channels_out); - output_rearranged.CopyRowsFromVec(bias); + output_rearranged.CopyRowsFromVec(*bias); } Matrix params_rearranged(c.filter_width * c.filter_height, @@ -603,6 +605,9 @@ void ConvolutionComputation:: ConvolveBackwardBias(const CuMatrixBase &output_deriv, BaseFloat alpha, CuVectorBase *bias_deriv) const { + if (bias_deriv == nullptr) { + return; + } #if HAVE_CUDA == 1 if (CuDevice::Instantiate().Enabled()) { CUDNN_SAFE_CALL(cudnnConvolutionBackwardBias( @@ -624,6 +629,9 @@ void ConvolutionComputation:: ConvolveBackwardBias(const MatrixBase &output_deriv, BaseFloat alpha, VectorBase *bias_deriv) const { + if (bias_deriv == nullptr) { + return; + } // TODO. } diff --git a/src/nnet3/convolution-cudnn.h b/src/nnet3/convolution-cudnn.h index 22dc8535a4c..fb9a42affab 100644 --- a/src/nnet3/convolution-cudnn.h +++ b/src/nnet3/convolution-cudnn.h @@ -152,7 +152,7 @@ class ConvolutionComputation final { */ void ConvolveForward(const CuMatrixBase &input, const CuMatrixBase ¶ms, - const CuVectorBase &bias, + const CuVectorBase *bias, CuMatrixBase *output) const; /** @@ -195,7 +195,7 @@ class ConvolutionComputation final { // compile for CUDA or we did but we are not using a GPU. void ConvolveForward(const MatrixBase &input, const MatrixBase ¶ms, - const VectorBase &bias, + const VectorBase *bias, MatrixBase *output) const; void ConvolveBackwardData(const MatrixBase ¶ms, const MatrixBase &output_deriv, From bfea6c8b1cdbdf0d039238c1b22808813bb32c25 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 28 Oct 2018 13:31:57 -0400 Subject: [PATCH 21/22] [egs] Remove unnecessary alignment from mini_librispeech run.sh, thanks: johnjosephmorgan@gmail.com --- egs/mini_librispeech/s5/run.sh | 3 --- 1 file changed, 3 deletions(-) diff --git a/egs/mini_librispeech/s5/run.sh b/egs/mini_librispeech/s5/run.sh index 3ab9d243ef6..997557f7904 100755 --- a/egs/mini_librispeech/s5/run.sh +++ b/egs/mini_librispeech/s5/run.sh @@ -170,9 +170,6 @@ if [ $stage -le 7 ]; then utils/build_const_arpa_lm.sh \ data/local/lm/lm_tglarge.arpa.gz data/lang data/lang_test_tglarge - - steps/align_fmllr.sh --nj 5 --cmd "$train_cmd" \ - data/train_clean_5 data/lang exp/tri3b exp/tri3b_ali_train_clean_5 fi From a2de7b97dbdc6b79354789c544086b72247dec39 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 31 Oct 2018 01:16:17 -0400 Subject: [PATCH 22/22] [src] Small cosmetic changes --- src/ivector/ivector-extractor.h | 2 +- src/nnet3/convolution-cudnn.h | 2 +- src/nnet3/nnet-simple-component.h | 20 +++++++++++++++++--- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/src/ivector/ivector-extractor.h b/src/ivector/ivector-extractor.h index 3b9b6f3eb5c..938034859e2 100644 --- a/src/ivector/ivector-extractor.h +++ b/src/ivector/ivector-extractor.h @@ -468,7 +468,7 @@ struct IvectorExtractorEstimationOptions { "update any associated parameters."); opts->Register("diagonalize", &diagonalize, "If true, diagonalize the quadratic term in the " - "objective function. This reorders the ivector dimensions" + "objective function. This reorders the ivector dimensions " "from most to least important."); } }; diff --git a/src/nnet3/convolution-cudnn.h b/src/nnet3/convolution-cudnn.h index fb9a42affab..a84bc995b64 100644 --- a/src/nnet3/convolution-cudnn.h +++ b/src/nnet3/convolution-cudnn.h @@ -147,7 +147,7 @@ class ConvolutionComputation final { @param [in] input NWHC fully-packed tensor, with NumRows() == N * W @param [in] params KCWH fully-packed tensor, with NumRows() == K. - @param [in] bias vector of length K + @param [in] pointer bias vector of length K (or NULL if we're not using a bias). @param [out] output Pre-allocated NWHK fully-packed tensor, with NumRows() == N * W. */ void ConvolveForward(const CuMatrixBase &input, diff --git a/src/nnet3/nnet-simple-component.h b/src/nnet3/nnet-simple-component.h index 11c60f8f352..89ad44f50c7 100644 --- a/src/nnet3/nnet-simple-component.h +++ b/src/nnet3/nnet-simple-component.h @@ -1091,9 +1091,23 @@ class SumGroupComponent: public Component { }; -/// FixedScaleComponent applies a fixed per-element scale; it's similar -/// to the Rescale component in the nnet1 setup (and only needed for nnet1 -/// model conversion). +/** + FixedScaleComponent applies a fixed per-element scale; it's similar + to the Rescale component in the nnet1 setup (and only needed for nnet1 + model conversion). + + Configuration values accepted by this component: + scales A filename, e.g. scales=foo/bar/scales.vec. The file should + contain something readable as a Vector; the text form is like: + [ 0.5 0.5 0.2 ] + dim Only accepted if 'scales' is not set, the dimension of the + scale. This is not very useful any more: scales that are the same + for each dimension can now be captured in Descriptors, e.g. + Scale(2.0, some_component_node). + scale If 'dim' is set, the value to which the scale should be + set (will be a constant). Otherwise it will be random (which is + useful only for testing purposes). +*/ class FixedScaleComponent: public Component { public: FixedScaleComponent() { }