33 #include <cudnn_backend.h> 58 ss <<
"CUDNN_BACKEND_POINTWISE_DESCRIPTOR :" 66 case CUDNN_POINTWISE_ADD:
67 case CUDNN_POINTWISE_MUL:
68 case CUDNN_POINTWISE_MIN:
69 case CUDNN_POINTWISE_MAX:
71 case CUDNN_POINTWISE_SQRT:
72 case CUDNN_POINTWISE_RELU_FWD:
73 case CUDNN_POINTWISE_TANH_FWD:
74 case CUDNN_POINTWISE_SIGMOID_FWD:
75 case CUDNN_POINTWISE_ELU_FWD:
104 cudnnPointwiseMode_t
mode = CUDNN_POINTWISE_ADD;
122 m_pointWiseDesc.math_precision = data_type_;
128 m_pointWiseDesc.upper_clip = u;
129 m_pointWiseDesc.lower_clip = l;
135 m_pointWiseDesc.mode = mode_;
141 m_pointWiseDesc.nan_propagation = nan_mode_;
151 auto status = m_pointWiseDesc.initialize_managed_backend_pointer(CUDNN_BACKEND_POINTWISE_DESCRIPTOR);
152 if (
status != CUDNN_STATUS_SUCCESS) {
154 &m_pointWiseDesc,
status,
"CUDNN_BACKEND_POINTWISE_DESCRIPTOR: cudnnCreate Failed");
155 return std::move(m_pointWiseDesc);
159 status = cudnnBackendSetAttribute(m_pointWiseDesc.pointer->get_backend_descriptor(),
160 CUDNN_ATTR_POINTWISE_MODE,
161 CUDNN_TYPE_POINTWISE_MODE,
163 &m_pointWiseDesc.mode);
164 if (
status != CUDNN_STATUS_SUCCESS) {
168 "CUDNN_BACKEND_POINTWISE_DESCRIPTOR: CUDNN_TYPE_POINTWISE_MODE SetAttribute Failed");
169 return std::move(m_pointWiseDesc);
172 status = cudnnBackendSetAttribute(m_pointWiseDesc.pointer->get_backend_descriptor(),
173 CUDNN_ATTR_POINTWISE_MATH_PREC,
174 CUDNN_TYPE_DATA_TYPE,
176 &m_pointWiseDesc.math_precision);
177 if (
status != CUDNN_STATUS_SUCCESS) {
181 "CUDNN_BACKEND_POINTWISE_DESCRIPTOR: SetAttribute CUDNN_ATTR_POINTWISE_MATH_PREC Failed");
182 return std::move(m_pointWiseDesc);
185 if (m_pointWiseDesc.mode == CUDNN_POINTWISE_RELU_FWD) {
186 status = cudnnBackendSetAttribute(m_pointWiseDesc.pointer->get_backend_descriptor(),
187 CUDNN_ATTR_POINTWISE_NAN_PROPAGATION,
188 CUDNN_TYPE_NAN_PROPOGATION,
190 &m_pointWiseDesc.nan_propagation);
191 if (
status != CUDNN_STATUS_SUCCESS) {
195 "CUDNN_BACKEND_POINTWISE_DESCRIPTOR: SetAttribute CUDNN_ATTR_POINTWISE_NAN_PROPAGATION Failed");
196 return std::move(m_pointWiseDesc);
199 status = cudnnBackendSetAttribute(m_pointWiseDesc.pointer->get_backend_descriptor(),
200 CUDNN_ATTR_POINTWISE_RELU_LOWER_CLIP,
203 &m_pointWiseDesc.lower_clip);
204 if (
status != CUDNN_STATUS_SUCCESS) {
208 "CUDNN_BACKEND_POINTWISE_DESCRIPTOR: SetAttribute CUDNN_ATTR_POINTWISE_RELU_LOWER_CLIP, Failed");
209 return std::move(m_pointWiseDesc);
212 status = cudnnBackendSetAttribute(m_pointWiseDesc.pointer->get_backend_descriptor(),
213 CUDNN_ATTR_POINTWISE_RELU_UPPER_CLIP,
216 &m_pointWiseDesc.upper_clip);
217 if (
status != CUDNN_STATUS_SUCCESS) {
221 "CUDNN_BACKEND_POINTWISE_DESCRIPTOR: SetAttribute CUDNN_ATTR_POINTWISE_RELU_UPPER_CLIP, Failed");
222 return std::move(m_pointWiseDesc);
227 status = cudnnBackendFinalize(m_pointWiseDesc.pointer->get_backend_descriptor());
228 if (
status != CUDNN_STATUS_SUCCESS) {
230 &m_pointWiseDesc,
status,
"CUDNN_BACKEND_POINTWISE_DESCRIPTOR: cudnnFinalize Failed");
231 return std::move(m_pointWiseDesc);
234 return std::move(m_pointWiseDesc);
PointWiseDesc_v8()=default
static void set_error_and_throw_exception(BackendDescriptor const *desc, cudnnStatus_t status, const char *message)
auto setClipping(double l, double u) -> PointWiseDescBuilder_v8 &
Set upper and lower limits for the RELU activation.
PointWiseDesc_v8 & operator=(PointWiseDesc_v8 const &)=delete
auto setMode(cudnnNanPropagation_t nan_mode_) -> PointWiseDescBuilder_v8 &
Set NaN propagation mode.
PointWiseDesc_v8(PointWiseDesc_v8 &&from)
cudnnPointwiseMode_t getPointWiseMode() const
cudnnNanPropagation_t nan_propagation
~PointWiseDesc_v8()=default
ManagedOpaqueDescriptor get_desc() const
Returns a copy of the underlying managed descriptor.
friend class PointWiseDescBuilder_v8
std::string describe() const override
Return a string describing the backend Descriptor.
int64_t getPortCount() const
cudnnStatus_t get_status() const
Current status of the descriptor.
PointWiseDesc_v8 m_pointWiseDesc
const char * get_error() const
Diagnostic error message, if any.
cudnnDataType_t math_precision
auto setMathPrecision(cudnnDataType_t data_type_) -> PointWiseDescBuilder_v8 &
Set Math Precision Data Type for the Pointwise Operation.
auto setMode(cudnnPointwiseMode_t mode_) -> PointWiseDescBuilder_v8 &
Set the pointwise mode (e.g. ADD, MUL, RELU_FWD) of the operation.
cudnnPointwiseMode_t mode
cudnnStatus_t status
Shared pointer of the OpaqueBackendPointer.
PointWiseDesc_v8 && build()