Created
December 17, 2018 14:49
-
-
Save wolfv/efa7da7e2f1d9b23f2657d934333b01c to your computer and use it in GitHub Desktop.
xtensor sparse experiments
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/*************************************************************************** | |
* Copyright (c) 2016, Johan Mabille, Sylvain Corlay and Wolf Vollprecht * | |
* * | |
* Distributed under the terms of the BSD 3-Clause License. * | |
* * | |
* The full license is in the file LICENSE, distributed with this software. * | |
****************************************************************************/ | |
#ifndef XTENSOR_SPARSE_HPP | |
#define XTENSOR_SPARSE_HPP | |
#include <array>
#include <cstddef>
#include <iostream>
#include <map>
#include <stdexcept>

#include <xtensor/xexpression.hpp>
#include <xtensor/xassign.hpp>
namespace xt | |
{ | |
template <class D> | |
class xsparse; | |
template <class D> | |
struct xcontainer_inner_types<xsparse<D>> | |
{ | |
using xexpression_type = xsparse<D>; | |
using temporary_type = xsparse<D>; | |
}; | |
template <class D> | |
struct xiterable_inner_types<xsparse<D>> | |
{ | |
using inner_shape_type = xt::dynamic_shape<size_t>; | |
using const_stepper = xindexed_stepper<xsparse<D>, true>; | |
using stepper = xindexed_stepper<xsparse<D>, false>; | |
}; | |
struct xsparse_expression_tag {}; | |
namespace extension | |
{ | |
template <class F, class... CT> | |
struct xfunction_base_impl<xsparse_expression_tag, F, CT...> | |
{ | |
using type = xtensor_empty_base; | |
}; | |
template <class CT, class S, layout_type L, class FST> | |
struct xstrided_view_base_impl<xsparse_expression_tag, CT, S, L, FST> | |
{ | |
using type = xtensor_empty_base; | |
}; | |
} | |
namespace detail | |
{ | |
template <class F, class... E> | |
struct select_xfunction_expression<xsparse_expression_tag, F, E...> | |
{ | |
using type = xfunction<F, E...>; | |
}; | |
} | |
template <> | |
class xexpression_assigner<xsparse_expression_tag> | |
{ | |
public: | |
template <class E1, class E2> | |
static void assign_xexpression(E1& e1, const E2& e2) | |
{ | |
for (auto&& key : e1.keys()) | |
{ | |
std::cout << "Assign xexp"; | |
} | |
} | |
template <class E1, class E2> | |
static void assign_data(xexpression<E1>& e1, const xexpression<E2>& e2, bool trivial) | |
{ | |
} | |
template <class E1, class E2> | |
static void computed_assign(xexpression<E1>& e1, const xexpression<E2>& e2) | |
{ | |
} | |
template <class E1, class E2, class F> | |
static void scalar_computed_assign(xexpression<E1>& e1, const E2& e2, F&& f) | |
{ | |
std::cout << "scalar assign to sparse" << std::endl; | |
for (auto&& el : e1.derived_cast().map()) | |
{ | |
el.second = f(el.second, e2); | |
} | |
} | |
template <class E1, class E2> | |
static void assert_compatible_shape(const xexpression<E1>& e1, const xexpression<E2>& e2) | |
{ | |
} | |
private: | |
template <class E1, class E2> | |
static bool resize(E1& e1, const E2& e2); | |
template <class E1, class F, class... CT> | |
static bool resize(E1& e1, const xfunction<F, CT...>& e2); | |
}; | |
template <class D> | |
class xsparse : public xcontainer_semantic<xsparse<D>>, | |
public xiterable<xsparse<D>> | |
{ | |
public: | |
using expression_tag = xsparse_expression_tag; | |
using value_type = D; | |
using reference = value_type&; | |
using const_reference = const value_type&; | |
using pointer = value_type*; | |
using const_pointer = const value_type*; | |
using difference_type = std::ptrdiff_t; | |
static value_type NA; | |
using self_type = xsparse<D>; | |
using shape_type = xt::dynamic_shape<size_t>; | |
using size_type = std::size_t; | |
using iterable_base = xiterable<self_type>; | |
using inner_shape_type = typename iterable_base::inner_shape_type; | |
using strides_type = get_strides_t<shape_type>; | |
using stepper = typename iterable_base::stepper; | |
using const_stepper = typename iterable_base::const_stepper; | |
constexpr static layout_type static_layout = layout_type::dynamic; | |
static constexpr bool contiguous_layout = false; | |
xsparse() = delete; | |
using iterable_base::begin; | |
using iterable_base::end; | |
xsparse(const shape_type& shape) | |
: m_shape(shape) | |
{ | |
} | |
template <class O> | |
bool broadcast_shape(O& shape, bool trivial = false) const | |
{ | |
return xt::broadcast_shape(m_shape, shape); | |
} | |
template <class ST> | |
bool is_trivial_broadcast(const ST& strides) const | |
{ | |
return false; | |
} | |
template <class... Args> | |
reference operator()(Args... args) | |
{ | |
std::array<size_t, sizeof...(Args)> ix = {static_cast<size_t>(args)...}; | |
return element(ix.begin(), ix.end()); | |
} | |
template <class... Args> | |
reference operator()(Args... args) const | |
{ | |
std::array<size_t, sizeof...(Args)> ix = {static_cast<size_t>(args)...}; | |
return element(ix.begin(), ix.end()); | |
} | |
template <class It> | |
reference element(It begin, It end) | |
{ | |
// std::cout << "getting el" << std::endl; | |
shape_type ix(begin, end); | |
auto it = m_contents.find(ix); | |
if (it == m_contents.end()) | |
{ | |
// throw std::runtime_error("Index doesn't exist my friendo!"); | |
return NA; | |
} | |
return it->second; | |
} | |
template <class It> | |
const_reference element(It begin, It end) const | |
{ | |
// std::cout << "getting elc" << std::endl; | |
shape_type ix(begin, end); | |
auto it = m_contents.find(ix); | |
if (it == m_contents.end()) | |
{ | |
return NA; | |
// throw std::runtime_error("Index doesn't exist my friendo!"); | |
// return 0; | |
} | |
return it->second; | |
} | |
value_type& set(const shape_type& ix) | |
{ | |
if (!(ix < m_shape)) | |
{ | |
throw std::runtime_error("Index > shape."); | |
} | |
m_contents[ix] = value_type(); | |
return m_contents[ix]; | |
} | |
auto& shape() const | |
{ | |
return m_shape; | |
} | |
std::size_t dimension() const | |
{ | |
return m_shape.size(); | |
} | |
template <class ST> | |
stepper stepper_begin(const ST& shape) | |
{ | |
size_type offset = shape.size() - dimension(); | |
return stepper(this, offset); | |
} | |
template <class ST> | |
stepper stepper_end(const ST& shape, layout_type) | |
{ | |
size_type offset = shape.size() - dimension(); | |
return stepper(this, offset, true); | |
} | |
template <class ST> | |
const_stepper stepper_begin(const ST& shape) const | |
{ | |
size_type offset = shape.size() - dimension(); | |
return const_stepper(this, offset); | |
} | |
template <class ST> | |
const_stepper stepper_end(const ST& shape, layout_type) const | |
{ | |
size_type offset = shape.size() - dimension(); | |
return const_stepper(this, offset, true); | |
} | |
layout_type layout() const | |
{ | |
return layout_type::dynamic; | |
} | |
std::map<shape_type, value_type>& map() | |
{ | |
return m_contents; | |
} | |
private: | |
std::map<shape_type, value_type> m_contents; | |
shape_type m_shape; | |
}; | |
template <class D> | |
typename xsparse<D>::value_type xsparse<D>::NA = typename xsparse<D>::value_type(0); | |
} | |
#endif | |
#include <xtensor/xio.hpp> | |
int main() | |
{ | |
xt::xsparse<double> arg({10, 10}); | |
arg.set({0, 0}) = 10; | |
arg.set({2, 3}) = 32; | |
// print array | |
std::cout << arg << std::endl; | |
xt::xsparse<double> b = arg; | |
b *= 123; | |
// print result | |
std::cout << b << std::endl; | |
// select + print single element | |
std::cout << arg(0, 0) << std::endl; | |
// print xfunction | |
// std::cout << arg * 12 << std::endl; | |
// print from STL style iterators | |
// for (auto& el : arg) { std::cout << el << std::endl;} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment