//===- COO.h - Coordinate-scheme sparse tensor representation ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file is part of the lightweight runtime support library for sparse
// tensor manipulations. The functionality of the support library is meant
// to simplify benchmarking, testing, and debugging MLIR code operating on
// sparse tensors. However, the provided functionality is **not** part of
// core MLIR itself.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_EXECUTIONENGINE_SPARSETENSOR_COO_H
#define MLIR_EXECUTIONENGINE_SPARSETENSOR_COO_H

#include <algorithm>
#include <cassert>
#include <cinttypes>
#include <functional>
#include <vector>

namespace mlir {
namespace sparse_tensor {

/// A sparse tensor element in coordinate scheme (value and indices).
/// For example, a rank-1 vector element would look like
///   ({i}, a[i])
/// and a rank-5 tensor element like
///   ({i,j,k,l,m}, a[i,j,k,l,m])
/// We use a pointer into a shared index pool rather than, e.g., a direct
/// vector, since that (1) reduces the per-element memory footprint, and
/// (2) centralizes the memory reservation and (re)allocation to one place.
template <typename V>
struct Element final {
  Element(uint64_t *ind, V val) : indices(ind), value(val) {}
  uint64_t *indices; // pointer into shared index pool
  V value;
};
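
// Layout sketch (illustrative only; `a` and `b` are placeholder values):
// after adding the two rank-2 elements ({0,1}, a) and ({2,3}, b), the shared
// pool and the elements conceptually hold
//
//   uint64_t pool[] = {0, 1, 2, 3};  // shared index pool
//   Element<double> e0(&pool[0], a); // indices -> {0, 1}
//   Element<double> e1(&pool[2], b); // indices -> {2, 3}
//
// so each element stores a single pointer rather than owning its own vector.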

/// The type of callback functions which receive an element. We avoid
/// packaging the coordinates and value together as an `Element` object
/// because this helps keep code somewhat cleaner.
template <typename V>
using ElementConsumer =
    const std::function<void(const std::vector<uint64_t> &, V)> &;
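
// For illustration only: a hypothetical helper `forEachElement` (not part of
// this library) that invokes such a callback on every element of the
// SparseTensorCOO class defined below:
//
//   template <typename V>
//   void forEachElement(const SparseTensorCOO<V> &coo,
//                       ElementConsumer<V> yield) {
//     const uint64_t rank = coo.getRank();
//     std::vector<uint64_t> ind(rank);
//     for (const Element<V> &e : coo.getElements()) {
//       ind.assign(e.indices, e.indices + rank);
//       yield(ind, e.value);
//     }
//   }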

/// A memory-resident sparse tensor in coordinate scheme (collection of
/// elements). This data structure is used to read a sparse tensor from
/// any external format into memory and sort the elements lexicographically
/// by indices before passing it back to the client (most packed storage
/// formats require the elements to appear in lexicographic index order).
template <typename V>
struct SparseTensorCOO final {
public:
  SparseTensorCOO(const std::vector<uint64_t> &dimSizes, uint64_t capacity)
      : dimSizes(dimSizes) {
    if (capacity) {
      elements.reserve(capacity);
      indices.reserve(capacity * getRank());
    }
  }

  /// Adds element as indices and value.
  void add(const std::vector<uint64_t> &ind, V val) {
    assert(!iteratorLocked && "Attempt to add() after startIterator()");
    uint64_t *base = indices.data();
    uint64_t size = indices.size();
    uint64_t rank = getRank();
    assert(ind.size() == rank && "Element rank mismatch");
    for (uint64_t r = 0; r < rank; r++) {
      assert(ind[r] < dimSizes[r] && "Index is too large for the dimension");
      indices.push_back(ind[r]);
    }
    // This base only changes if indices were reallocated. In that case, we
    // need to correct all previous pointers into the vector. Note that this
    // only happens if we did not set the initial capacity right, and then only
    // for every internal vector reallocation (which with the doubling rule
    // should only incur an amortized linear overhead).
    uint64_t *newBase = indices.data();
    if (newBase != base) {
      for (uint64_t i = 0, n = elements.size(); i < n; i++)
        elements[i].indices = newBase + (elements[i].indices - base);
      base = newBase;
    }
    // Add element as (pointer into shared index pool, value) pair.
    elements.emplace_back(base + size, val);
  }
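
  // For illustration only: populating a small 10x10 tensor (the indices and
  // values here are made up for the example):
  //
  //   std::vector<uint64_t> dimSizes{10, 10};
  //   SparseTensorCOO<double> coo(dimSizes, /*capacity=*/4);
  //   coo.add({0, 0}, 1.0);
  //   coo.add({9, 3}, 2.0);
  //
  // Reserving a sufficient capacity up front avoids the pointer-fixup pass
  // above, since the shared index pool then never reallocates.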

  /// Sorts elements lexicographically by index.
  void sort() {
    assert(!iteratorLocked && "Attempt to sort() after startIterator()");
    // TODO: we may want to cache an `isSorted` bit, to avoid
    // unnecessary/redundant sorting.
    uint64_t rank = getRank();
    std::sort(elements.begin(), elements.end(),
              [rank](const Element<V> &e1, const Element<V> &e2) {
                for (uint64_t r = 0; r < rank; r++) {
                  if (e1.indices[r] == e2.indices[r])
                    continue;
                  return e1.indices[r] < e2.indices[r];
                }
                return false;
              });
  }

  /// Gets the rank of the tensor.
  uint64_t getRank() const { return dimSizes.size(); }

  /// Getter for the dimension-sizes array.
  const std::vector<uint64_t> &getDimSizes() const { return dimSizes; }

  /// Getter for the elements array.
  const std::vector<Element<V>> &getElements() const { return elements; }

  /// Switches into iterator mode.
  void startIterator() {
    iteratorLocked = true;
    iteratorPos = 0;
  }

  /// Gets the next element.
  const Element<V> *getNext() {
    assert(iteratorLocked && "Attempt to getNext() before startIterator()");
    if (iteratorPos < elements.size())
      return &(elements[iteratorPos++]);
    iteratorLocked = false;
    return nullptr;
  }
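
  // For illustration only: draining the tensor in sorted order through the
  // iterator interface (`coo` is assumed to point to a populated instance):
  //
  //   coo->sort();
  //   coo->startIterator();
  //   while (const Element<double> *e = coo->getNext()) {
  //     // e->indices points at getRank() coordinates; e->value is the value.
  //   }
  //
  // Note that getNext() unlocks the iterator again once it runs off the end,
  // so a subsequent add() or sort() is legal only after full exhaustion.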

  /// Factory method. Permutes the original dimensions according to
  /// the given ordering and expects subsequent add() calls to honor
  /// that same ordering for the given indices. The result is a
  /// fully permuted coordinate scheme.
  ///
  /// Precondition: `dimSizes` and `perm` must be valid for `rank`.
  static SparseTensorCOO<V> *newSparseTensorCOO(uint64_t rank,
                                                const uint64_t *dimSizes,
                                                const uint64_t *perm,
                                                uint64_t capacity = 0) {
    std::vector<uint64_t> permsz(rank);
    for (uint64_t r = 0; r < rank; r++) {
      assert(dimSizes[r] > 0 && "Dimension size zero has trivial storage");
      permsz[perm[r]] = dimSizes[r];
    }
    return new SparseTensorCOO<V>(permsz, capacity);
  }
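
  // For illustration only: creating a rank-2 tensor whose dimensions are
  // swapped by the permutation (the caller owns the returned object):
  //
  //   uint64_t dimSizes[] = {3, 4};
  //   uint64_t perm[] = {1, 0}; // original dimension d maps to perm[d]
  //   SparseTensorCOO<double> *coo =
  //       SparseTensorCOO<double>::newSparseTensorCOO(2, dimSizes, perm);
  //   // coo->getDimSizes() is now {4, 3}, and add() expects indices
  //   // given in the permuted order.
  //   delete coo;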

private:
  const std::vector<uint64_t> dimSizes; // per-dimension sizes
  std::vector<Element<V>> elements;     // all COO elements
  std::vector<uint64_t> indices;        // shared index pool
  bool iteratorLocked = false;
  unsigned iteratorPos = 0;
};

} // namespace sparse_tensor
} // namespace mlir

#endif // MLIR_EXECUTIONENGINE_SPARSETENSOR_COO_H