CUDNN API  8
cudnn_frontend_VariantPack.h
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice shall be included in
12  * all copies or substantial portions of the Software.
13  *
14  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
17  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20  * DEALINGS IN THE SOFTWARE.
21  */
22 
23 #pragma once
24 
25 #include <algorithm>
26 #include <array>
27 #include <functional>
28 #include <memory>
29 #include <set>
30 #include <sstream>
31 #include <utility>
32 
33 #include <cudnn.h>
34 #include <cudnn_backend.h>
35 
36 #include "cudnn_frontend_utils.h"
37 
38 namespace cudnn_frontend {
39 
53  public:
54  friend class VariantPackBuilder_v8;
55  std::string
56  describe() const override {
57  std::stringstream ss;
58  ss << "CUDNN_BACKEND_VARIANT_PACK_DESCRIPTOR :"
59  << " has " << num_ptrs << " data pointers";
60  return ss.str();
61  }
// Move-constructor initializer list and body (the declarator line is not shown
// in this listing; the parameter is named `from`).
// The BackendDescriptor base receives a copy of from's managed descriptor
// (get_desc()), plus its status and error text.
    : BackendDescriptor(from.get_desc(), from.get_status(), from.get_error()),
      workspace(from.workspace),
      num_ptrs(from.num_ptrs) {
    // Both members are fixed-size C arrays, so every slot is copied,
    // not only the first num_ptrs entries.
    std::copy(std::begin(from.data_pointers), std::end(from.data_pointers), data_pointers);
    std::copy(std::begin(from.uid), std::end(from.uid), uid);
}
69  ~VariantPack_v8() = default;
70 
71  private:
72  VariantPack_v8() = default;
73  VariantPack_v8(VariantPack_v8 const &) = delete;
75  operator=(VariantPack_v8 const &) = delete;
76 
77  void *workspace = nullptr;
78  void *data_pointers[10] = {nullptr};
79  int64_t uid[10] = {-1};
80  int64_t num_ptrs = -1;
81 };
82 
87  public:
92  auto
94  setDataPointers(int64_t num_ptr, void **ptrs) -> VariantPackBuilder_v8 & {
95  std::copy(ptrs, ptrs + num_ptr, m_variant_pack.data_pointers);
96  m_variant_pack.num_ptrs = num_ptr;
97  return *this;
98  }
100  auto
101  setUids(int64_t num_uids, int64_t *uid) -> VariantPackBuilder_v8 & {
102  std::copy(uid, uid + num_uids, m_variant_pack.uid);
103  return *this;
104  }
106  auto
107  setDataPointers(std::set<std::pair<uint64_t, void *>> const &data_pointers) -> VariantPackBuilder_v8 & {
108  auto i = 0;
109  for (auto &data_pointer : data_pointers) {
110  m_variant_pack.uid[i] = data_pointer.first;
111  m_variant_pack.data_pointers[i] = data_pointer.second;
112  i++;
113  }
114  m_variant_pack.num_ptrs = data_pointers.size();
115  return *this;
116  }
118  auto
120  m_variant_pack.workspace = ws;
121  return *this;
122  }
125  VariantPack_v8 &&
128  build() {
129  // Create a descriptor. Memory allocation happens here.
130  auto status = m_variant_pack.initialize_managed_backend_pointer(CUDNN_BACKEND_VARIANT_PACK_DESCRIPTOR);
131  if (status != CUDNN_STATUS_SUCCESS) {
133  &m_variant_pack, status, "CUDNN_BACKEND_VARIANT_PACK_DESCRIPTOR: cudnnCreate Failed");
134  return std::move(m_variant_pack);
135  }
136 
137  status = cudnnBackendSetAttribute(m_variant_pack.pointer->get_backend_descriptor(),
138  CUDNN_ATTR_VARIANT_PACK_DATA_POINTERS,
139  CUDNN_TYPE_VOID_PTR,
140  m_variant_pack.num_ptrs,
141  m_variant_pack.data_pointers);
142  if (status != CUDNN_STATUS_SUCCESS) {
144  &m_variant_pack,
145  status,
146  "CUDNN_BACKEND_VARIANT_PACK_DESCRIPTOR: SetAttribute CUDNN_ATTR_VARIANT_PACK_DATA_POINTERS Failed");
147  return std::move(m_variant_pack);
148  }
149 
150  status = cudnnBackendSetAttribute(m_variant_pack.pointer->get_backend_descriptor(),
151  CUDNN_ATTR_VARIANT_PACK_UNIQUE_IDS,
152  CUDNN_TYPE_INT64,
153  m_variant_pack.num_ptrs,
154  m_variant_pack.uid);
155  if (status != CUDNN_STATUS_SUCCESS) {
157  &m_variant_pack,
158  status,
159  "CUDNN_BACKEND_VARIANT_PACK_DESCRIPTOR: SetAttribute CUDNN_ATTR_VARIANT_PACK_UNIQUE_IDS Failed");
160  return std::move(m_variant_pack);
161  }
162 
163  status = cudnnBackendSetAttribute(m_variant_pack.pointer->get_backend_descriptor(),
164  CUDNN_ATTR_VARIANT_PACK_WORKSPACE,
165  CUDNN_TYPE_VOID_PTR,
166  1,
167  &m_variant_pack.workspace);
168  if (status != CUDNN_STATUS_SUCCESS) {
170  &m_variant_pack,
171  status,
172  "CUDNN_BACKEND_VARIANT_PACK_DESCRIPTOR: SetAttribute CUDNN_ATTR_VARIANT_PACK_WORKSPACE Failed");
173  return std::move(m_variant_pack);
174  }
175 
176  // Finalizing the descriptor
177  status = cudnnBackendFinalize(m_variant_pack.pointer->get_backend_descriptor());
178  if (status != CUDNN_STATUS_SUCCESS) {
180  &m_variant_pack, status, "CUDNN_BACKEND_VARIANT_PACK_DESCRIPTOR: cudnnFinalize Failed");
181  return std::move(m_variant_pack);
182  }
183  return std::move(m_variant_pack);
184  }
185 
186  explicit VariantPackBuilder_v8() = default;
187  ~VariantPackBuilder_v8() = default;
191  operator=(VariantPackBuilder_v8 const &) = delete;
192 
193  private:
195 };
196 }
static void set_error_and_throw_exception(BackendDescriptor const *desc, cudnnStatus_t status, const char *message)
auto setWorkspacePointer(void *ws) -> VariantPackBuilder_v8 &
Set the workspace pointer.
auto setDataPointers(int64_t num_ptr, void **ptrs) -> VariantPackBuilder_v8 &
Set dataPointers for the VariantPack_v8.
VariantPack_v8 & operator=(VariantPack_v8 const &)=delete
ManagedOpaqueDescriptor get_desc() const
Returns a copy of underlying managed descriptor.
auto setUids(int64_t num_uids, int64_t *uid) -> VariantPackBuilder_v8 &
Set Uids for the VariantPack_v8.
cudnnStatus_t get_status() const
Current status of the descriptor.
const char * get_error() const
Diagnostic error message, if any.
auto setDataPointers(std::set< std::pair< uint64_t, void *>> const &data_pointers) -> VariantPackBuilder_v8 &
Set the uids and data pointers from a set of (uid, data pointer) pairs.
std::string describe() const override
Return a string describing the backend Descriptor.
cudnnStatus_t status
Status (error code) recorded on the descriptor.