CUDNN API  8
cudnn_frontend_Operation.h
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice shall be included in
12  * all copies or substantial portions of the Software.
13  *
14  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
17  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20  * DEALINGS IN THE SOFTWARE.
21  */
22 
23 #pragma once
24 
25 #include <algorithm>
26 #include <array>
27 #include <functional>
28 #include <memory>
29 #include <sstream>
30 #include <utility>
31 
32 #include <cudnn.h>
33 #include <cudnn_backend.h>
34 
37 #include "cudnn_frontend_Tensor.h"
38 #include "cudnn_frontend_utils.h"
39 
40 namespace cudnn_frontend {
41 
57  public:
58  friend class OperationBuilder_v8;
59  std::string
60  describe() const override {
61  std::stringstream ss;
62  ss << "CUDNN_BACKEND_OPERATION :"
63  << " OpMode: " << std::to_string(op_mode);
64  ss << std::hex << " X " << xdesc;
65  ss << std::hex << " Y " << ydesc;
66  ss << std::hex << " W " << wdesc;
67  ss << std::hex << " B " << bdesc;
68  ss << std::hex << " C " << cdesc;
69  ss << std::hex << " P " << pwdesc;
70  ss << std::dec << " alphabetaType " << alphabetaType;
71  ss << " Alpha: " << alpha_s << " " << alpha_d;
72  ss << " Alpha2: " << alpha2_s << " " << alpha2_d;
73  ss << " Beta: " << beta_s << " " << beta_d;
74  return ss.str();
75  }
77  : BackendDescriptor(from.pointer, from.get_status(), from.get_error()),
78  op_mode(from.op_mode),
79  xdesc(from.xdesc),
80  ydesc(from.ydesc),
81  wdesc(from.wdesc),
82  bdesc(from.bdesc),
83  cdesc(from.cdesc),
84  pwdesc(from.pwdesc),
86  alpha_s(from.alpha_s),
87  alpha_d(from.alpha_d),
88  beta_s(from.beta_s),
89  beta_d(from.beta_d),
92  operationTag(from.operationTag) {}
93 
96  return ydesc;
97  }
98 
99  std::string const &
100  getTag() const {
101  return operationTag;
102  }
103 
104  ~Operation_v8() = default;
105 
106  private:
107  Operation_v8() = default;
108  Operation_v8(Operation_v8 const &) = delete;
109  Operation_v8 &
110  operator=(Operation_v8 const &) = delete;
111 
112  cudnnBackendDescriptorType_t op_mode = CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR;
113 
120 
121  cudnnBackendAttributeType_t alphabetaType = CUDNN_TYPE_FLOAT;
122  float alpha_s = 1.0f, beta_s = .0f, alpha2_s = 1.0f;
123  double alpha_d = 1.0, beta_d = 0.0, alpha2_d = 1.0;
124  int64_t pointwise_port_count = -1;
125  cudnnPointwiseMode_t pointwise_mode;
126  std::string operationTag;
127 };
128 
132 
134  private:
136  bool is_convolution_op = false;
137 
138  public:
143  auto
145  m_operation.xdesc = raw_tensor;
146  return *this;
147  }
148 
149  auto
150  setxDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 & {
151  m_operation.xdesc = tensor.get_desc();
152  return *this;
153  }
154  auto
155  setbDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 & {
156  if (is_convolution_op == true) {
158  &m_operation,
159  CUDNN_STATUS_BAD_PARAM,
160  "CUDNN_BACKEND_OPERATION_*_DESCRIPTOR: Convolution operation does not need bTensor");
161  }
162  m_operation.bdesc = tensor.get_desc();
163  return *this;
164  }
165  auto
166  setyDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 & {
167  m_operation.ydesc = tensor.get_desc();
168  return *this;
169  }
170  auto
171  setwDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 & {
172  if (is_convolution_op == false) {
174  &m_operation,
175  CUDNN_STATUS_BAD_PARAM,
176  "CUDNN_BACKEND_OPERATION_*_DESCRIPTOR: Non Convolution operation does not need wTensor");
177  }
178  m_operation.wdesc = tensor.get_desc();
179  return *this;
180  }
181  auto
183  if (is_convolution_op == false) {
185  &m_operation,
186  CUDNN_STATUS_BAD_PARAM,
187  "CUDNN_BACKEND_OPERATION_*_DESCRIPTOR: Non Convolution operation does not need Convolution DESCRIPTOR");
188  }
189  m_operation.cdesc = conv.get_desc();
190  return *this;
191  }
192  auto
193  setpwDesc(PointWiseDesc_v8 const &pointWiseDesc) -> OperationBuilder_v8 & {
194  if (is_convolution_op == true) {
196  &m_operation,
197  CUDNN_STATUS_BAD_PARAM,
198  "CUDNN_BACKEND_OPERATION_*_DESCRIPTOR: Convolution operation does not need POINTWISE DESCRIPTOR");
199  }
200  m_operation.pwdesc = pointWiseDesc.get_desc();
201  m_operation.pointwise_port_count = pointWiseDesc.getPortCount();
202  m_operation.pointwise_mode = pointWiseDesc.getPointWiseMode();
203  return *this;
204  }
205  auto
206  setAlpha(float alpha) -> OperationBuilder_v8 & {
207  m_operation.alphabetaType = CUDNN_TYPE_FLOAT;
208  m_operation.alpha_d = static_cast<double>(alpha);
209  m_operation.alpha_s = alpha;
210  return *this;
211  }
212  auto
213  setAlpha(double alpha) -> OperationBuilder_v8 & {
214  m_operation.alphabetaType = CUDNN_TYPE_DOUBLE;
215  m_operation.alpha_s = static_cast<float>(alpha);
216  m_operation.alpha_d = alpha;
217  return *this;
218  }
219  auto
220  setAlpha2(float alpha) -> OperationBuilder_v8 & {
221  m_operation.alphabetaType = CUDNN_TYPE_FLOAT;
222  m_operation.alpha2_d = static_cast<double>(alpha);
223  m_operation.alpha2_s = alpha;
224  return *this;
225  }
226  auto
227  setAlpha2(double alpha) -> OperationBuilder_v8 & {
228  m_operation.alphabetaType = CUDNN_TYPE_DOUBLE;
229  m_operation.alpha2_s = static_cast<float>(alpha);
230  m_operation.alpha2_d = alpha;
231  return *this;
232  }
233  auto
234  setBeta(float beta) -> OperationBuilder_v8 & {
235  m_operation.alphabetaType = CUDNN_TYPE_FLOAT;
236  m_operation.beta_d = static_cast<double>(beta);
237  m_operation.beta_s = beta;
238  return *this;
239  }
240  auto
241  setBeta(double beta) -> OperationBuilder_v8 & {
242  m_operation.alphabetaType = CUDNN_TYPE_DOUBLE;
243  m_operation.beta_s = static_cast<float>(beta);
244  m_operation.beta_d = beta;
245  return *this;
246  }
247 
248  OperationBuilder_v8(cudnnBackendDescriptorType_t mode) {
249  m_operation.op_mode = mode;
250  is_convolution_op = ((m_operation.op_mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) ||
251  (m_operation.op_mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR) ||
252  (m_operation.op_mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR));
253  }
256  Operation_v8 &&
259  build() {
260  if (m_operation.status != CUDNN_STATUS_SUCCESS) {
262  &m_operation, m_operation.status, "CUDNN_BACKEND_OPERATION: Operation not initialized properly");
263  return std::move(m_operation);
264  }
265  if (m_operation.xdesc == nullptr) {
267  &m_operation,
268  CUDNN_STATUS_BAD_PARAM,
269  "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_CONVOLUTION_*_X");
270  return std::move(m_operation);
271  }
272  if (m_operation.wdesc == nullptr && is_convolution_op) {
274  &m_operation,
275  CUDNN_STATUS_BAD_PARAM,
276  "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_CONVOLUTION_*_W");
277  return std::move(m_operation);
278  }
279  if (m_operation.ydesc == nullptr && is_convolution_op) {
281  &m_operation,
282  CUDNN_STATUS_BAD_PARAM,
283  "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_CONVOLUTION_*_Y");
284  return std::move(m_operation);
285  }
286  if (m_operation.cdesc == nullptr && is_convolution_op) {
288  &m_operation,
289  CUDNN_STATUS_BAD_PARAM,
290  "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_CONVOLUTION_*_CONV_DESC");
291  return std::move(m_operation);
292  }
293 
294  // Create the descriptor.
295  auto status = m_operation.initialize_managed_backend_pointer(m_operation.op_mode);
296  if (status != CUDNN_STATUS_SUCCESS) {
297  set_error_and_throw_exception(&m_operation, status, "CUDNN_BACKEND_OPERATION: cudnnCreate Failed");
298  return std::move(m_operation);
299  }
300  if (m_operation.op_mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) {
301  m_operation.operationTag = "ConvFwd";
302 
303  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
304  CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_X,
305  CUDNN_TYPE_BACKEND_DESCRIPTOR,
306  1,
307  &(m_operation.xdesc->get_backend_descriptor()));
308  if (status != CUDNN_STATUS_SUCCESS) {
310  &m_operation,
311  status,
312  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_X Failed");
313  return std::move(m_operation);
314  }
315  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
316  CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_W,
317  CUDNN_TYPE_BACKEND_DESCRIPTOR,
318  1,
319  &(m_operation.wdesc->get_backend_descriptor()));
320  if (status != CUDNN_STATUS_SUCCESS) {
322  &m_operation,
323  status,
324  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_W Failed");
325  return std::move(m_operation);
326  }
327  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
328  CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_Y,
329  CUDNN_TYPE_BACKEND_DESCRIPTOR,
330  1,
331  &(m_operation.ydesc->get_backend_descriptor()));
332  if (status != CUDNN_STATUS_SUCCESS) {
334  &m_operation,
335  status,
336  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_Y Failed");
337  return std::move(m_operation);
338  }
339  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
340  CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_CONV_DESC,
341  CUDNN_TYPE_BACKEND_DESCRIPTOR,
342  1,
343  &(m_operation.cdesc->get_backend_descriptor()));
344  if (status != CUDNN_STATUS_SUCCESS) {
346  &m_operation,
347  status,
348  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_CONV_DESC Failed");
349  return std::move(m_operation);
350  }
351  void *alpha = (m_operation.alphabetaType == CUDNN_TYPE_FLOAT ? static_cast<void *>(&m_operation.alpha_s)
352  : static_cast<void *>(&m_operation.alpha_d));
353  void *beta = (m_operation.alphabetaType == CUDNN_TYPE_FLOAT ? static_cast<void *>(&m_operation.beta_s)
354  : static_cast<void *>(&m_operation.beta_d));
355  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
356  CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_ALPHA,
357  m_operation.alphabetaType,
358  1,
359  alpha);
360  if (status != CUDNN_STATUS_SUCCESS) {
362  &m_operation,
363  status,
364  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_ALPHA Failed");
365  return std::move(m_operation);
366  }
367  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
368  CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_BETA,
369  m_operation.alphabetaType,
370  1,
371  beta);
372  if (status != CUDNN_STATUS_SUCCESS) {
374  &m_operation,
375  status,
376  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_BETA Failed");
377  return std::move(m_operation);
378  }
379  } else if (m_operation.op_mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR) {
380  m_operation.operationTag = "ConvBwdFilter";
381 
382  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
383  CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_X,
384  CUDNN_TYPE_BACKEND_DESCRIPTOR,
385  1,
386  &(m_operation.xdesc->get_backend_descriptor()));
387  if (status != CUDNN_STATUS_SUCCESS) {
389  &m_operation,
390  status,
391  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_X Failed");
392  return std::move(m_operation);
393  }
394  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
395  CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DW,
396  CUDNN_TYPE_BACKEND_DESCRIPTOR,
397  1,
398  &(m_operation.wdesc->get_backend_descriptor()));
399  if (status != CUDNN_STATUS_SUCCESS) {
401  &m_operation,
402  status,
403  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DW Failed");
404  return std::move(m_operation);
405  }
406  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
407  CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DY,
408  CUDNN_TYPE_BACKEND_DESCRIPTOR,
409  1,
410  &(m_operation.ydesc->get_backend_descriptor()));
411  if (status != CUDNN_STATUS_SUCCESS) {
413  &m_operation,
414  status,
415  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DY Failed");
416  return std::move(m_operation);
417  }
418  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
419  CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_CONV_DESC,
420  CUDNN_TYPE_BACKEND_DESCRIPTOR,
421  1,
422  &(m_operation.cdesc->get_backend_descriptor()));
423  if (status != CUDNN_STATUS_SUCCESS) {
424  set_error_and_throw_exception(&m_operation,
425  status,
426  "CUDNN_BACKEND_OPERATION: SetAttribute "
427  "CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_CONV_DESC Failed");
428  return std::move(m_operation);
429  }
430  void *alpha = (m_operation.alphabetaType == CUDNN_TYPE_FLOAT ? static_cast<void *>(&m_operation.alpha_s)
431  : static_cast<void *>(&m_operation.alpha_d));
432  void *beta = (m_operation.alphabetaType == CUDNN_TYPE_FLOAT ? static_cast<void *>(&m_operation.beta_s)
433  : static_cast<void *>(&m_operation.beta_d));
434  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
435  CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_ALPHA,
436  m_operation.alphabetaType,
437  1,
438  alpha);
439  if (status != CUDNN_STATUS_SUCCESS) {
441  &m_operation,
442  status,
443  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_ALPHA Failed");
444  return std::move(m_operation);
445  }
446  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
447  CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_BETA,
448  m_operation.alphabetaType,
449  1,
450  beta);
451  if (status != CUDNN_STATUS_SUCCESS) {
453  &m_operation,
454  status,
455  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_BETA Failed");
456  return std::move(m_operation);
457  }
458  } else if (m_operation.op_mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) {
459  m_operation.operationTag = "ConvBwdData";
460 
461  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
462  CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DX,
463  CUDNN_TYPE_BACKEND_DESCRIPTOR,
464  1,
465  &(m_operation.xdesc->get_backend_descriptor()));
466  if (status != CUDNN_STATUS_SUCCESS) {
468  &m_operation,
469  status,
470  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DX Failed");
471  return std::move(m_operation);
472  }
473  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
474  CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_W,
475  CUDNN_TYPE_BACKEND_DESCRIPTOR,
476  1,
477  &(m_operation.wdesc->get_backend_descriptor()));
478  if (status != CUDNN_STATUS_SUCCESS) {
480  &m_operation,
481  status,
482  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_W Failed");
483  return std::move(m_operation);
484  }
485  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
486  CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DY,
487  CUDNN_TYPE_BACKEND_DESCRIPTOR,
488  1,
489  &(m_operation.ydesc->get_backend_descriptor()));
490  if (status != CUDNN_STATUS_SUCCESS) {
492  &m_operation,
493  status,
494  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DY Failed");
495  return std::move(m_operation);
496  }
497  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
498  CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_CONV_DESC,
499  CUDNN_TYPE_BACKEND_DESCRIPTOR,
500  1,
501  &(m_operation.cdesc->get_backend_descriptor()));
502  if (status != CUDNN_STATUS_SUCCESS) {
504  &m_operation,
505  status,
506  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_CONV_DESC Failed");
507  return std::move(m_operation);
508  }
509  void *alpha = (m_operation.alphabetaType == CUDNN_TYPE_FLOAT ? static_cast<void *>(&m_operation.alpha_s)
510  : static_cast<void *>(&m_operation.alpha_d));
511  void *beta = (m_operation.alphabetaType == CUDNN_TYPE_FLOAT ? static_cast<void *>(&m_operation.beta_s)
512  : static_cast<void *>(&m_operation.beta_d));
513  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
514  CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_ALPHA,
515  m_operation.alphabetaType,
516  1,
517  alpha);
518  if (status != CUDNN_STATUS_SUCCESS) {
520  &m_operation,
521  status,
522  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_ALPHA Failed");
523  return std::move(m_operation);
524  }
525  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
526  CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_BETA,
527  m_operation.alphabetaType,
528  1,
529  beta);
530  if (status != CUDNN_STATUS_SUCCESS) {
532  &m_operation,
533  status,
534  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_BETA Failed");
535  return std::move(m_operation);
536  }
537  } else if (m_operation.op_mode == CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) {
538  switch (m_operation.pointwise_mode) {
539  case CUDNN_POINTWISE_ADD:
540  m_operation.operationTag = "Add";
541  break;
542  case CUDNN_POINTWISE_MUL:
543  m_operation.operationTag = "Mul";
544  break;
545  case CUDNN_POINTWISE_MIN:
546  m_operation.operationTag = "Min";
547  break;
548  case CUDNN_POINTWISE_MAX:
549  m_operation.operationTag = "Max";
550  break;
551  case CUDNN_POINTWISE_SQRT:
552  m_operation.operationTag = "Sqrt";
553  break;
554  case CUDNN_POINTWISE_RELU_FWD:
555  m_operation.operationTag = "ReluFwd";
556  break;
557  case CUDNN_POINTWISE_TANH_FWD:
558  m_operation.operationTag = "TanhFwd";
559  break;
560  case CUDNN_POINTWISE_SIGMOID_FWD:
561  m_operation.operationTag = "SigmoidFwd";
562  break;
563  case CUDNN_POINTWISE_ELU_FWD:
564  m_operation.operationTag = "EluFwd";
565  break;
566  }
567 
568  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
569  CUDNN_ATTR_OPERATION_POINTWISE_XDESC,
570  CUDNN_TYPE_BACKEND_DESCRIPTOR,
571  1,
572  &(m_operation.xdesc->get_backend_descriptor()));
573  if (status != CUDNN_STATUS_SUCCESS) {
575  &m_operation,
576  status,
577  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_POINTWISE_XDESC Failed");
578  return std::move(m_operation);
579  }
580  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
581  CUDNN_ATTR_OPERATION_POINTWISE_PW_DESCRIPTOR,
582  CUDNN_TYPE_BACKEND_DESCRIPTOR,
583  1,
584  &(m_operation.pwdesc->get_backend_descriptor()));
585  if (status != CUDNN_STATUS_SUCCESS) {
587  &m_operation,
588  status,
589  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_POINTWISE_PW_DESCRIPTOR Failed");
590  return std::move(m_operation);
591  }
592  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
593  CUDNN_ATTR_OPERATION_POINTWISE_YDESC,
594  CUDNN_TYPE_BACKEND_DESCRIPTOR,
595  1,
596  &(m_operation.ydesc->get_backend_descriptor()));
597  if (status != CUDNN_STATUS_SUCCESS) {
599  &m_operation,
600  status,
601  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_POINTWISE_YDESC Failed");
602  return std::move(m_operation);
603  }
604  void *alpha = (m_operation.alphabetaType == CUDNN_TYPE_FLOAT ? static_cast<void *>(&m_operation.alpha_s)
605  : static_cast<void *>(&m_operation.alpha_d));
606  void *alpha2 = (m_operation.alphabetaType == CUDNN_TYPE_FLOAT ? static_cast<void *>(&m_operation.alpha2_s)
607  : static_cast<void *>(&m_operation.alpha2_d));
608  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
609  CUDNN_ATTR_OPERATION_POINTWISE_ALPHA1,
610  m_operation.alphabetaType,
611  1,
612  alpha);
613  if (status != CUDNN_STATUS_SUCCESS) {
615  &m_operation,
616  status,
617  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_POINTWISE_ALPHA1 Failed");
618  return std::move(m_operation);
619  }
620  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
621  CUDNN_ATTR_OPERATION_POINTWISE_ALPHA2,
622  m_operation.alphabetaType,
623  1,
624  alpha2);
625  if (status != CUDNN_STATUS_SUCCESS) {
627  &m_operation,
628  status,
629  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_POINTWISE_ALPHA2 Failed");
630  return std::move(m_operation);
631  }
632  if (m_operation.pointwise_port_count == 3) {
633  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
634  CUDNN_ATTR_OPERATION_POINTWISE_BDESC,
635  CUDNN_TYPE_BACKEND_DESCRIPTOR,
636  1,
637  &(m_operation.bdesc->get_backend_descriptor()));
638  if (status != CUDNN_STATUS_SUCCESS) {
640  &m_operation,
641  status,
642  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_POINTWISE_BDESC Failed");
643  return std::move(m_operation);
644  }
645  }
646  }
647  status = cudnnBackendFinalize(m_operation.pointer->get_backend_descriptor());
648  if (status != CUDNN_STATUS_SUCCESS) {
649  set_error_and_throw_exception(&m_operation, status, "CUDNN_BACKEND_OPERATION: cudnnFinalize Failed");
650  return std::move(m_operation);
651  }
652  return std::move(m_operation);
653  }
654 };
655 }
auto setcDesc(ConvDesc_v8 const &conv) -> OperationBuilder_v8 &
cudnnStatus_t initialize_managed_backend_pointer(cudnnBackendDescriptorType_t type)
Initializes the underlying managed descriptor.
static void set_error_and_throw_exception(BackendDescriptor const *desc, cudnnStatus_t status, const char *message)
auto setAlpha(float alpha) -> OperationBuilder_v8 &
auto setwDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 &
Operation_v8 & operator=(Operation_v8 const &)=delete
auto setbDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 &
cudnnBackendDescriptorType_t op_mode
auto setBeta(float beta) -> OperationBuilder_v8 &
auto setpwDesc(PointWiseDesc_v8 const &pointWiseDesc) -> OperationBuilder_v8 &
auto setAlpha2(float alpha) -> OperationBuilder_v8 &
cudnnStatus_t get_status() const
Current status of the descriptor.
auto setBeta(double beta) -> OperationBuilder_v8 &
std::shared_ptr< OpaqueBackendPointer > ManagedOpaqueDescriptor
std::string describe() const override
Return a string describing the backend Descriptor.
const char * get_error() const
Diagnostic error message, if any.
cudnnBackendAttributeType_t alphabetaType
auto setxDesc(ManagedOpaqueDescriptor const &raw_tensor) -> OperationBuilder_v8 &
auto setyDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 &
auto setxDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 &
ManagedOpaqueDescriptor getOutputTensor()
auto setAlpha2(double alpha) -> OperationBuilder_v8 &
OperationBuilder_v8(cudnnBackendDescriptorType_t mode)
std::string const & getTag() const
auto setAlpha(double alpha) -> OperationBuilder_v8 &
cudnnStatus_t status
Shared pointer of the OpaqueBackendPointer.