Skip to content

Instantly share code, notes, and snippets.

@wolfv
Created December 17, 2018 14:49
Show Gist options
  • Save wolfv/efa7da7e2f1d9b23f2657d934333b01c to your computer and use it in GitHub Desktop.
Save wolfv/efa7da7e2f1d9b23f2657d934333b01c to your computer and use it in GitHub Desktop.
xtensor sparse experiments
/***************************************************************************
* Copyright (c) 2016, Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
* *
* Distributed under the terms of the BSD 3-Clause License. *
* *
* The full license is in the file LICENSE, distributed with this software. *
****************************************************************************/
#ifndef XTENSOR_SPARSE_HPP
#define XTENSOR_SPARSE_HPP
#include <array>
#include <cstddef>
#include <map>
#include <stdexcept>

#include <xtensor/xexpression.hpp>
#include <xtensor/xassign.hpp>
namespace xt
{
// Forward declaration so the trait specializations can name xsparse<D>.
template <typename D>
class xsparse;

// Container traits for xsparse: both the expression type and the temporary
// type alias xsparse<D> itself.
template <typename D>
struct xcontainer_inner_types<xsparse<D>>
{
    using temporary_type = xsparse<D>;
    using xexpression_type = xsparse<D>;
};
// Iteration traits for xsparse: a dynamic shape, and xindexed_stepper for
// both the mutable (second argument false) and const (true) steppers.
template <typename D>
struct xiterable_inner_types<xsparse<D>>
{
    using inner_shape_type = xt::dynamic_shape<size_t>;
    using stepper = xindexed_stepper<xsparse<D>, false>;
    using const_stepper = xindexed_stepper<xsparse<D>, true>;
};
// Expression tag carried by xsparse (its expression_tag alias); xtensor
// dispatches on it to pick the sparse-specific specializations.
struct xsparse_expression_tag
{
};

namespace extension
{
    // xfunctions built from sparse operands derive only from the empty base.
    template <typename F, typename... CT>
    struct xfunction_base_impl<xsparse_expression_tag, F, CT...>
    {
        using type = xtensor_empty_base;
    };

    // Strided views over sparse expressions likewise get only the empty base.
    template <typename CT, typename S, layout_type L, typename FST>
    struct xstrided_view_base_impl<xsparse_expression_tag, CT, S, L, FST>
    {
        using type = xtensor_empty_base;
    };
}
namespace detail
{
    // For the sparse tag, the expression type produced from operator F and
    // operands E... is the plain xfunction<F, E...>.
    template <typename F, typename... E>
    struct select_xfunction_expression<xsparse_expression_tag, F, E...>
    {
        using type = xfunction<F, E...>;
    };
}
/// Assignment back-end selected for expressions tagged xsparse_expression_tag.
///
/// The implemented operations only visit the entries the destination already
/// stores in its map(); they never add or remove entries.
template <>
class xexpression_assigner<xsparse_expression_tag>
{
public:
    /// Overwrites every stored entry of @p e1 with @p e2 evaluated at the same
    /// index. Indices that e1 does not store stay unstored.
    /// Both arguments may be passed either as the concrete type or wrapped in
    /// xexpression<>; derived_cast() yields the concrete object in both cases.
    /// NOTE(review): values of e2 at indices outside e1's stored pattern are
    /// not copied — confirm this is the intended sparse assignment semantics.
    template <class E1, class E2>
    static void assign_xexpression(E1& e1, const E2& e2)
    {
        assign_stored_entries(e1.derived_cast(), e2.derived_cast());
    }

    /// Currently a no-op stub.
    template <class E1, class E2>
    static void assign_data(xexpression<E1>& /*e1*/, const xexpression<E2>& /*e2*/, bool /*trivial*/)
    {
    }

    /// NOTE(review): currently a no-op stub, so any computed assignment routed
    /// here leaves @p e1 unchanged.
    template <class E1, class E2>
    static void computed_assign(xexpression<E1>& /*e1*/, const xexpression<E2>& /*e2*/)
    {
    }

    /// Replaces every stored entry v of @p e1 with f(v, e2), presumably reached
    /// from scalar compound assignments such as `b *= 123`. Unstored indices
    /// are not visited, so for an f where f(NA, e2) != NA (e.g. `+=` with a
    /// non-zero scalar) only the stored entries change.
    template <class E1, class E2, class F>
    static void scalar_computed_assign(xexpression<E1>& e1, const E2& e2, F&& f)
    {
        for (auto& entry : e1.derived_cast().map())
        {
            entry.second = f(entry.second, e2);
        }
    }

    /// Currently a no-op stub: no shape compatibility is checked.
    template <class E1, class E2>
    static void assert_compatible_shape(const xexpression<E1>& /*e1*/, const xexpression<E2>& /*e2*/)
    {
    }

private:
    // Writes src evaluated at each index stored in dst back into that entry.
    template <class Dst, class Src>
    static void assign_stored_entries(Dst& dst, const Src& src)
    {
        for (auto& entry : dst.map())
        {
            const auto& index = entry.first;
            entry.second = src.element(index.begin(), index.end());
        }
    }

    template <class E1, class E2>
    static bool resize(E1& e1, const E2& e2);
    template <class E1, class F, class... CT>
    static bool resize(E1& e1, const xfunction<F, CT...>& e2);
};
/// Experimental sparse N-dimensional container.
///
/// Stored entries live in an ordered std::map keyed by the full index vector.
/// Any index without a stored entry reads as the static sentinel NA.
/// Entries are created through set().
///
/// @tparam D element (value) type.
template <class D>
class xsparse : public xcontainer_semantic<xsparse<D>>,
                public xiterable<xsparse<D>>
{
public:
    using expression_tag = xsparse_expression_tag;
    using value_type = D;
    using reference = value_type&;
    using const_reference = const value_type&;
    using pointer = value_type*;
    using const_pointer = const value_type*;
    using difference_type = std::ptrdiff_t;

    /// Value reported for indices that have no stored entry.
    /// NOTE(review): the non-const accessors return a mutable reference to this
    /// single static object. Writing through it (e.g. `a(i, j) = v` on an
    /// unset index) therefore rewrites the "missing" value seen by every
    /// xsparse<D>, instead of storing an entry. Use set() to create entries.
    static value_type NA;

    using self_type = xsparse<D>;
    using shape_type = xt::dynamic_shape<size_t>;
    using size_type = std::size_t;
    using iterable_base = xiterable<self_type>;
    using inner_shape_type = typename iterable_base::inner_shape_type;
    using strides_type = get_strides_t<shape_type>;
    using stepper = typename iterable_base::stepper;
    using const_stepper = typename iterable_base::const_stepper;
    constexpr static layout_type static_layout = layout_type::dynamic;
    static constexpr bool contiguous_layout = false;

    xsparse() = delete;

    using iterable_base::begin;
    using iterable_base::end;

    /// Creates a container of the given shape with no stored entries.
    xsparse(const shape_type& shape)
        : m_shape(shape)
    {
    }

    /// Broadcasts this container's shape into @p shape via xt::broadcast_shape.
    template <class O>
    bool broadcast_shape(O& shape, bool trivial = false) const
    {
        return xt::broadcast_shape(m_shape, shape);
    }

    /// A sparse container is never treated as trivially broadcastable.
    template <class ST>
    bool is_trivial_broadcast(const ST& strides) const
    {
        return false;
    }

    /// Mutable access by integral indices; yields NA for unset indices.
    template <class... Args>
    reference operator()(Args... args)
    {
        std::array<size_t, sizeof...(Args)> ix = {static_cast<size_t>(args)...};
        return element(ix.begin(), ix.end());
    }

    /// Read-only access by integral indices; yields NA for unset indices.
    /// Returns const_reference because the const element() overload does.
    template <class... Args>
    const_reference operator()(Args... args) const
    {
        std::array<size_t, sizeof...(Args)> ix = {static_cast<size_t>(args)...};
        return element(ix.begin(), ix.end());
    }

    /// Mutable access by an index range [begin, end). Returns the stored entry,
    /// or the shared NA sentinel when none exists (see the note on NA).
    template <class It>
    reference element(It begin, It end)
    {
        shape_type ix(begin, end);
        auto it = m_contents.find(ix);
        if (it == m_contents.end())
        {
            return NA;
        }
        return it->second;
    }

    /// Read-only access by an index range [begin, end). Returns the stored
    /// entry, or NA when none exists.
    template <class It>
    const_reference element(It begin, It end) const
    {
        shape_type ix(begin, end);
        auto it = m_contents.find(ix);
        if (it == m_contents.end())
        {
            return NA;
        }
        return it->second;
    }

    /// Creates (or resets) the entry at @p ix to value_type() and returns a
    /// reference to it for assignment.
    /// @throws std::runtime_error if @p ix does not have the container's rank
    ///         or any coordinate is not smaller than the matching extent.
    value_type& set(const shape_type& ix)
    {
        // Check each coordinate against its own extent. Comparing the whole
        // index vectors with operator< is lexicographic and would accept e.g.
        // {0, 15}, {10, 0} or {1} for a {10, 10} shape.
        bool in_bounds = ix.size() == m_shape.size();
        for (size_type i = 0; in_bounds && i < ix.size(); ++i)
        {
            in_bounds = ix[i] < m_shape[i];
        }
        if (!in_bounds)
        {
            throw std::runtime_error("Index > shape.");
        }
        // Single lookup: operator[] inserts if absent, then reset the value.
        value_type& slot = m_contents[ix];
        slot = value_type();
        return slot;
    }

    /// The container's shape.
    auto& shape() const
    {
        return m_shape;
    }

    /// Number of dimensions (length of the shape).
    std::size_t dimension() const
    {
        return m_shape.size();
    }

    // Steppers are offset by the difference between the broadcast rank and
    // this container's rank.
    template <class ST>
    stepper stepper_begin(const ST& shape)
    {
        size_type offset = shape.size() - dimension();
        return stepper(this, offset);
    }

    template <class ST>
    stepper stepper_end(const ST& shape, layout_type)
    {
        size_type offset = shape.size() - dimension();
        return stepper(this, offset, true);
    }

    template <class ST>
    const_stepper stepper_begin(const ST& shape) const
    {
        size_type offset = shape.size() - dimension();
        return const_stepper(this, offset);
    }

    template <class ST>
    const_stepper stepper_end(const ST& shape, layout_type) const
    {
        size_type offset = shape.size() - dimension();
        return const_stepper(this, offset, true);
    }

    /// Always layout_type::dynamic.
    layout_type layout() const
    {
        return layout_type::dynamic;
    }

    /// Direct mutable access to the stored (index -> value) entries.
    std::map<shape_type, value_type>& map()
    {
        return m_contents;
    }

    /// Read-only access to the stored (index -> value) entries.
    const std::map<shape_type, value_type>& map() const
    {
        return m_contents;
    }

private:
    std::map<shape_type, value_type> m_contents;
    shape_type m_shape;
};
// Out-of-class definition of the static "missing entry" sentinel declared in
// xsparse; value_type is D, so it is initialised as D(0).
template <class D>
typename xsparse<D>::value_type xsparse<D>::NA = D(0);
}
#endif
#include <xtensor/xio.hpp>
int main()
{
xt::xsparse<double> arg({10, 10});
arg.set({0, 0}) = 10;
arg.set({2, 3}) = 32;
// print array
std::cout << arg << std::endl;
xt::xsparse<double> b = arg;
b *= 123;
// print result
std::cout << b << std::endl;
// select + print single element
std::cout << arg(0, 0) << std::endl;
// print xfunction
// std::cout << arg * 12 << std::endl;
// print from STL style iterators
// for (auto& el : arg) { std::cout << el << std::endl;}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment