Last active
November 10, 2016 15:15
-
-
Save ktnyt/d83f9859264feb19610884618ea3b70c to your computer and use it in GitHub Desktop.
Simple C++ CBLAS wrapper
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| /****************************************************************************** | |
| * | |
| * laplus/laplus.hpp | |
| * | |
| * Copyright (C) 2016 Kotone Itaya | |
| * | |
| * Licensed to the Apache Software Foundation (ASF) under one | |
| * or more contributor license agreements. See the NOTICE file | |
| * distributed with this work for additional information | |
| * regarding copyright ownership. The ASF licenses this file | |
| * to you under the Apache License, Version 2.0 (the | |
| * "License"); you may not use this file except in compliance | |
| * with the License. You may obtain a copy of the License at | |
| * | |
| * http://www.apache.org/licenses/LICENSE-2.0 | |
| * | |
| * Unless required by applicable law or agreed to in writing, | |
| * software distributed under the License is distributed on an | |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
| * KIND, either express or implied. See the License for the | |
| * specific language governing permissions and limitations | |
| * under the License. | |
| * | |
| *****************************************************************************/ | |
| #ifndef __LAPLUS__ | |
| #define __LAPLUS__ | |
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <cstring>
#include <functional>
#include <memory>
#include <ostream>
#include <type_traits>
#include <utility>
#include "cblas.h"
| namespace laplus { | |
| using Transpose = enum CBLAS_TRANSPOSE; | |
| template<std::size_t Rows, std::size_t Cols=1, bool Trans=false> | |
| class Array; | |
| template<bool Trans> | |
| class Array<1, 1, Trans> { | |
| public: | |
| Array() = delete; | |
| Array(const Array& other) | |
| : shared(other.shared), buffer(other.buffer) | |
| , offset(other.offset), stride(other.shared) {} | |
| Array(Array&& other) noexcept | |
| : shared(other.shared), buffer(other.buffer) | |
| , offset(other.offset), stride(other.shared) | |
| { other.shared = nullptr; } | |
| template<std::size_t Rows> | |
| Array(const Array<Rows>& other, const std::size_t offset, const std::size_t stride) | |
| : shared(other.shared), buffer(other.buffer), offset(offset), stride(stride) {} | |
| void operator=(const float& value) { *(this->buffer + this->offset) = value; } | |
| operator float&() { return *(this->buffer + this->offset); } | |
| friend std::ostream& operator<<(std::ostream& ostream, const Array& array) | |
| { return ostream << *(array.buffer + array.offset); } | |
| private: | |
| std::shared_ptr<float> shared; | |
| float* buffer; | |
| std::size_t offset; | |
| std::size_t stride; | |
| }; | |
| template<std::size_t Rows, std::size_t Cols, bool Trans> | |
| class Array { | |
| template<std::size_t Rows2, std::size_t Cols2, bool Trans2> | |
| friend class Array; | |
| static constexpr std::size_t Size = Rows * Cols; | |
| static constexpr Transpose trans = Trans ? CblasTrans : CblasNoTrans; | |
| public: | |
| // Constructors & Destructor | |
| Array(const float value=0.0) | |
| : shared(std::shared_ptr<float>(new float[Rows*Cols], std::default_delete<float[]>())) | |
| , buffer(shared.get()), offset(0), stride(1) | |
| { for(std::size_t i = 0; i < Rows * Cols; ++i) *(this->buffer + i) = value; } | |
| Array(const std::array<float, Rows*Cols> values) | |
| : shared(std::shared_ptr<float>(new float[Rows*Cols], std::default_delete<float[]>())) | |
| , buffer(shared.get()), offset(0), stride(1) | |
| { std::copy(values.begin(), values.end(), this->buffer); } | |
| Array(const Array& other) | |
| : shared(other.shared), buffer(other.buffer) | |
| , offset(other.offset), stride(other.stride) {} | |
| Array(Array&& other) noexcept | |
| : shared(other.shared), buffer(other.buffer) | |
| , offset(other.offset), stride(other.stride) | |
| { other.shared = nullptr; } | |
| virtual ~Array() {} | |
| // Assignment Operators | |
| Array& operator=(const Array& other) | |
| { | |
| Array another(other); | |
| *this = std::move(another); | |
| return *this; | |
| } | |
| Array& operator=(Array&& other) noexcept | |
| { | |
| swap(*this, other); | |
| return *this; | |
| } | |
| // Member Access Operators | |
| Array<Cols> operator[](const std::size_t index) const | |
| { | |
| std::size_t offset = Trans ? index : Cols * index; | |
| std::size_t stride = Trans ? Rows : 1; | |
| return Array<Cols>(*this, offset, stride); | |
| } | |
| float& operator()(const std::size_t i) const | |
| { return *(this->buffer + i); } | |
| float& operator()(const std::size_t i, const std::size_t j) const | |
| { return *(this->buffer + this->offset + (Trans ? Rows : Cols) * i + j); } | |
| // Miscellaneous Operators | |
| friend std::ostream& operator<<(std::ostream& ostream, const Array& array) | |
| { | |
| ostream << "["; | |
| for(std::size_t i = 0; i < Rows; ++i) { | |
| if(i > 0 && Cols > 1) ostream << std::endl; | |
| if(i > 0) ostream << " "; | |
| if(Cols > 1) ostream << "["; | |
| for(std::size_t j = 0; j < Cols; ++j) { | |
| if(j > 0) ostream << " "; | |
| if(Trans) ostream << array(j, i); | |
| else ostream << array(i, j); | |
| } | |
| if(Cols > 1) ostream << "]"; | |
| } | |
| return ostream << "]"; | |
| } | |
| // Utilities | |
| friend void swap(Array& a, Array& b) | |
| { | |
| std::swap(a.shared, b.shared); | |
| std::swap(a.buffer, b.buffer); | |
| std::swap(a.offset, b.offset); | |
| std::swap(a.stride, b.stride); | |
| } | |
| Array clone() const | |
| { | |
| Array cloned; | |
| cloned.copy(*this); | |
| return cloned; | |
| } | |
| // Accessors | |
| static constexpr std::size_t rows() { return Rows; } | |
| static constexpr std::size_t cols() { return Cols; } | |
| static constexpr std::size_t size() { return Size; } | |
| const std::size_t use_count() const { return shared.use_count(); } | |
| // Level 1 BLAS | |
| void scal(const float alpha) | |
| { cblas_sscal(Size, alpha, this->buffer + this->offset, this->stride); } | |
| void copy(const Array& other) | |
| { cblas_scopy(Size, other.buffer + other.offset, other.stride, | |
| this->buffer + this->offset, this->stride); } | |
| void axpy(const float alpha, const Array& other) | |
| { cblas_saxpy(Size, alpha, other.buffer + other.offset, other.stride, | |
| this->buffer + this->offset, this->stride); } | |
| const float dot(const Array& other) const | |
| { return cblas_sdot(Size, other.buffer + other.offset, other.stride, | |
| this->buffer + this->offset, this->stride); } | |
| const float norm() const | |
| { return cblas_snrm2(Size, this->buffer + this->offset, this->stride); } | |
| const float asum() const | |
| { return cblas_sasum(Size, this->buffer + this->offset, this->stride); } | |
| const float iamax() const | |
| { return cblas_isamax(Size, this->buffer + this->offset, this->stride); } | |
| // Level 2 BLAS | |
| template<std::size_t Rows2, std::size_t Cols2> | |
| void gemv(const float alpha, const Array<Rows2*Cols2, Rows*Cols>& A, | |
| const Array<Rows2, Cols2>& x, const float beta) | |
| { cblas_sgemv(CblasColMajor, A.trans, Rows * Cols, Rows2 * Cols2, alpha, | |
| A.buffer, Rows * Cols, x.buffer + x.offset, x.stride, beta, | |
| this->buffer + this->offset, this->stride); } | |
| template<std::size_t Rows1, std::size_t Cols1, std::size_t Rows2, std::size_t Cols2> | |
| void ger(const float alpha, const Array<Rows1, Cols1>& y, | |
| const Array<Rows2, Cols2>& x) | |
| { cblas_sger(CblasColMajor, Rows2 * Cols2, Rows1 * Cols1, alpha, | |
| x.buffer + x.offset, x.stride, y.buffer + x.offset, y.stride, | |
| this->buffer, Rows2 * Cols2); } | |
| // Level 3 BLAS | |
| template<std::size_t Shared> | |
| void gemm(const float alpha, const Array<Shared, Rows, false>& A, | |
| const Array<Cols, Shared, false>& B, const float beta) | |
| { cblas_sgemm(CblasColMajor, A.trans, B.trans, Rows, Cols, Shared, | |
| alpha, A.buffer, Rows, B.buffer, Shared, | |
| beta, this->buffer, Cols); } | |
| template<std::size_t Shared> | |
| void gemm(const float alpha, const Array<Shared, Rows, true>& A, | |
| const Array<Cols, Shared, false>& B, const float beta) | |
| { cblas_sgemm(CblasColMajor, A.trans, B.trans, Rows, Cols, Shared, | |
| alpha, A.buffer, Shared, B.buffer, Shared, | |
| beta, this->buffer, Cols); } | |
| template<std::size_t Shared> | |
| void gemm(const float alpha, const Array<Shared, Rows, false>& A, | |
| const Array<Cols, Shared, true>& B, const float beta) | |
| { cblas_sgemm(CblasColMajor, A.trans, B.trans, Rows, Cols, Shared, | |
| alpha, A.buffer, Rows, B.buffer, Cols, | |
| beta, this->buffer, Cols); } | |
| template<std::size_t Shared> | |
| void gemm(const float alpha, const Array<Shared, Rows, true>& A, | |
| const Array<Cols, Shared, true>& B, const float beta) | |
| { cblas_sgemm(CblasColMajor, A.trans, B.trans, Rows, Cols, Shared, | |
| alpha, A.buffer, Shared, B.buffer, Cols, | |
| beta, this->buffer, Cols); } | |
| // Extensions | |
| Array(const Array<Cols, Rows, !Trans>& other) | |
| : shared(other.shared), buffer(other.buffer) | |
| , offset(other.offset), stride(other.stride) {} | |
| Array<Cols, Rows, !Trans> transpose() const | |
| { return Array<Cols, Rows, !Trans>(*this); } | |
| template<std::size_t Rows2, std::size_t Cols2> | |
| Array(const Array<Rows2, Cols2, Trans>& other) | |
| : shared(other.shared), buffer(other.buffer) | |
| , offset(other.offset), stride(other.stride) | |
| { static_assert(Rows2*Cols2==Size, "total size of new array must be unchanged"); } | |
| template<std::size_t Rows2, std::size_t Cols2> | |
| Array<Rows2, Cols2, Trans> reshape() const | |
| { return Array<Rows2, Cols2, Trans>(*this); } | |
| template<std::size_t Rows2> | |
| Array(const Array<Rows2, Rows*Cols>& other, | |
| const std::size_t offset, const std::size_t stride) | |
| : shared(other.shared), buffer(other.buffer), offset(offset), stride(stride) | |
| {} | |
| // Linear Algebra | |
| template<std::size_t Shared> | |
| Array<Rows, Shared> dot(const Array<Cols, Shared>& other) const | |
| { | |
| Array<Rows, Shared> result; | |
| result.gemm(1.0, *this, other, 0.0); | |
| return result; | |
| } | |
| private: | |
| std::shared_ptr<float> shared; | |
| float* buffer; | |
| std::size_t offset; | |
| std::size_t stride; | |
| }; | |
| template<std::size_t Rows, std::size_t Cols, std::size_t Shared, bool Trans1, bool Trans2> | |
| Array<Rows, Cols> dot(const Array<Rows, Shared, Trans1>& A, const Array<Shared, Cols, Trans2>& B) | |
| { | |
| Array<Rows, Cols> result; | |
| result.gemm(1.0, B, A, 0.0); | |
| return result; | |
| } | |
| template<std::size_t Rows, std::size_t Cols> | |
| Array<Rows, Cols> outer(const Array<Rows>& x, const Array<Cols>& y) | |
| { | |
| Array<Rows, Cols> result; | |
| result.ger(1.0, x, y); | |
| return result; | |
| } | |
| } | |
| #endif |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.