/*
 * tensor.hpp
 *
 */
#ifndef TENSOR_HPP
#define TENSOR_HPP

namespace RFOLD {
/**
 * Dense tensor whose rank (at most MAX_RANK) is chosen at run time and whose
 * elements are stored in one flat container.
 *
 * `ranges_` holds the length of every dimension and `strides_` the flat-index
 * step of every dimension.  Strides are assigned from the last dimension
 * backwards, so the last index varies fastest.
 *
 * A dimension of length 0 gets stride 0 and does not contribute to the
 * element count, so every index along it addresses the same element.
 *
 * @tparam T    element type (only used to select the default container)
 * @tparam Data flat container supplying the container typedefs used below
 *              plus begin()/end(), front()/back(), size(), empty(), resize(),
 *              clear() and operator[]
 */
template <typename T, typename Data = Array<T> >
class Tensor {
public:
  typedef typename Data::value_type value_type;
  typedef typename Data::iterator iterator;
  typedef typename Data::const_iterator const_iterator;
  typedef typename Data::reference reference;
  typedef typename Data::const_reference const_reference;
  typedef typename Data::size_type size_type;
  typedef typename Data::difference_type difference_type;
  /// Largest rank accepted by set_size0().
  enum {MAX_RANK = 5};
  Tensor() {}
  Tensor(const Tensor& tsr): data_(tsr.data_), strides_(tsr.strides_), ranges_(tsr.ranges_) {}
  ~Tensor() {}
  Tensor& operator=(const Tensor& tsr) {
    data_ = tsr.data_;
    strides_ = tsr.strides_;
    ranges_ = tsr.ranges_;
    return *this;
  }

  // --- Flat, container-like access to the underlying storage ---------------
  bool empty() const {return data_.empty();}
  size_type size() const {return data_.size();}
  iterator begin() {return data_.begin();}
  const_iterator begin() const {return data_.begin();}
  iterator end() {return data_.end();}
  const_iterator end() const {return data_.end();}
  reference front() {return data_.front();}
  const_reference front() const {return data_.front();}
  reference back() {return data_.back();}
  const_reference back() const {return data_.back();}
  /// Assigns `t` to every stored element.
  void fill(const_reference t) {std::fill(data_.begin(), data_.end(), t);}
  /// Drops all elements together with the shape information.
  void clear() {data_.clear(); strides_.clear(); ranges_.clear();}

  // --- Shape queries ---------------------------------------------------------
  /// Number of dimensions.
  int rank() const {return ranges_.size();}
  /// Length of dimension `i_th_arg` (not range-checked here).
  int length(int i_th_arg) const {return ranges_[i_th_arg];}
  /// Copy of the lengths of all dimensions.
  Array<int> lengths() const {return ranges_;}

  /**
   * Sets the shape from the dimension lengths in [first, last) and resizes
   * the storage to the product of all positive lengths (1 if there is none).
   *
   * Calls Check() on the rank limit and Die() on a negative length.
   */
  template <typename InpIt>
  void set_size0(InpIt first, InpIt last) {
    ranges_.assign(first, last);
    Check(rank() <= MAX_RANK);
    strides_.resize(ranges_.size());
    strides_.fill(0);
    int n = 1;  // running product of the positive lengths seen so far
    for (int k = (ranges_.size() - 1); k >= 0; k--) {
      int l = ranges_[k];
      if (l > 0) {
        strides_[k] = n;
        n *= l;
      } else if (l == 0) {
        // no-op: the stride stays 0 and the element count is unaffected
      } else {
        Die("ranges[%d] = %d < 0", k, l);
      }
    }
    data_.resize(n);
  }
  // Fixed-arity conveniences for set_size0().  The rank-0 overload passes an
  // empty range over a one-element dummy buffer, because a zero-length array
  // (`int l[] = {}`) is not valid ISO C++.
  void set_size() {int l[1] = {0}; set_size0(l, l);}
  void set_size(int i0) {int l[] = {i0}; set_size0(l, l + 1);}
  void set_size(int i0, int i1) {int l[] = {i0,i1}; set_size0(l, l + 2);}
  void set_size(int i0, int i1, int i2) {int l[] = {i0,i1,i2}; set_size0(l, l + 3);}
  void set_size(int i0, int i1, int i2, int i3) {
    int l[] = {i0,i1,i2,i3}; set_size0(l, l + 4);}
  void set_size(int i0, int i1, int i2, int i3, int i4) {
    int l[] = {i0,i1,i2,i3,i4}; set_size0(l, l + 5);}

  // Element access by one index per dimension.  Each overload Assert()s that
  // the number of indices equals rank().  The rank-0 overloads hand a dummy
  // buffer to get_index0<0>, which never reads it.
  reference ref() {
    Assert(rank()==0); int l[1]={0}; return data_[get_index0<0, const int*>(l)];}
  reference ref(int i0) {
    Assert(rank()==1); int l[]={i0}; return data_[get_index0<1, const int*>(l)];}
  reference ref(int i0, int i1) {
    Assert(rank()==2); int l[]={i0,i1}; return data_[get_index0<2, const int*>(l)];}
  reference ref(int i0, int i1, int i2) {
    Assert(rank()==3); int l[]={i0,i1,i2}; return data_[get_index0<3, const int*>(l)];}
  reference ref(int i0, int i1, int i2, int i3) {
    Assert(rank()==4); int l[]={i0,i1,i2,i3}; return data_[get_index0<4, const int*>(l)];}
  reference ref(int i0, int i1, int i2, int i3, int i4) {
    Assert(rank()==5); int l[]={i0,i1,i2,i3,i4}; return data_[get_index0<5, const int*>(l)];}
  const_reference ref() const {
    Assert(rank()==0); int l[1]={0}; return data_[get_index0<0, const int*>(l)];}
  const_reference ref(int i0) const {
    Assert(rank()==1); int l[]={i0}; return data_[get_index0<1, const int*>(l)];}
  const_reference ref(int i0, int i1) const {
    Assert(rank()==2); int l[]={i0,i1}; return data_[get_index0<2, const int*>(l)];}
  const_reference ref(int i0, int i1, int i2) const {
    Assert(rank()==3); int l[]={i0,i1,i2}; return data_[get_index0<3, const int*>(l)];}
  const_reference ref(int i0, int i1, int i2, int i3) const {
    Assert(rank()==4); int l[]={i0,i1,i2,i3}; return data_[get_index0<4, const int*>(l)];}
  const_reference ref(int i0, int i1, int i2, int i3, int i4) const {
    Assert(rank()==5); int l[]={i0,i1,i2,i3,i4}; return data_[get_index0<5, const int*>(l)];}
  // template <typename InpIt>
  // reference ref0(InpIt first, InpIt last) {
  //   Assert(rank()==(int)std::distance(first, last)); 
  //   return data_[get_index(std::distance(first,last), first)];
  // }
  // template <typename InpIt>
  // const_reference ref0(InpIt first, InpIt last) const {
  //   Assert(rank()==(int)std::distance(first, last)); 
  //   return data_[get_index((int)std::distance(first, last), first)];
  // }

  /// Flat index of the element addressed by `idxs` (one index per dimension).
  int get_index(const Array<int>& idxs) const {
    Assert(rank() == (int)idxs.size());
    return inner_product(strides_.begin(), strides_.end(), idxs.begin(), 0);
  }
  /**
   * Decomposes flat index `n` into one index per dimension.
   * Dimensions of length 0 are reported as -1.
   */
  Array<int> get_indexes(int n) const {
    Array<int> indexes(rank());
    for (int k = (rank() - 1); k >= 0; k--) {
      int l = ranges_[k];
      int r = (-1);
      if (l > 0) {
        r = n % l;
        n /= l;
      }
      indexes[k] = r;
    }
    return indexes;
  }
  /// One line per stored element, formatted as "[ i, j ]: value".
  std::string to_s() const {
    std::ostringstream oss;
    for (int i = 0; i < (int)data_.size(); i++) {
      const Array<int>& indexes = get_indexes(i);
      oss << "[ ";
      for (int k = 0; k < (int)indexes.size(); k++) {
        oss << indexes[k] 
            << (k == ((int)indexes.size() - 1) ? " " : ", ");
      }
      oss << "]: ";
      oss << data_[i] << "\n";
    }
    return oss.str();
  }
  /// Writes to_s() to std::cout and flushes.
  void print() const {std::cout << to_s() << std::flush;}
  
private:
  Data data_;           // flat element storage
  Array<int> strides_;  // flat-index step per dimension (0 for length-0 dims)
  Array<int> ranges_;   // length of every dimension

  /**
   * Flat index for the RANK indices starting at `p`: the sum of
   * index * stride over the first RANK dimensions.  When DEBUG is defined
   * each index is also Assert()ed to lie inside its dimension, unless that
   * dimension has length 0.
   */
  template <int RANK, typename InpIt>
  inline int get_index0(InpIt p) const {
    int idx = 0;
    Array<int>::const_iterator it = strides_.begin(); 
    for (int m = 0; m < RANK; m++) {
#ifdef DEBUG
      int i = (*p);
      int k = std::distance(strides_.begin(), it);
      Assert((ranges_[k] == 0) || (0 <= i && i < ranges_[k]),
             "ranges_[%d]=%d, strides_[%d]=%d, i=%d, rank=%d, m=%d, idx=%d", 
             k, ranges_[k], k, strides_[k], i, rank(), m, idx);
#endif
      idx += ((*p++) * (*it++));
    }
    return idx;
  }
};

/**
 * Variant of Tensor whose maximum rank is a template parameter and whose
 * strides live in a fixed-size CArray<int, MAX_RANK>.
 *
 * `ranges_` holds the length of every dimension and `strides_` the flat-index
 * step of every dimension.  Strides are assigned from the last dimension
 * backwards, so the last index varies fastest.  A dimension of length 0 gets
 * stride 0 and does not contribute to the element count.
 *
 * @tparam T        element type (only used to select the default container)
 * @tparam MAX_RANK capacity of the stride table; set_size0() Check()s that
 *                  the requested rank does not exceed it
 * @tparam Data     flat container supplying the container typedefs used
 *                  below plus begin()/end(), front()/back(), size(), empty(),
 *                  resize(), clear() and operator[]
 */
template <typename T, int MAX_RANK, typename Data = Array<T> >
class Tensor1 {
public:
  typedef typename Data::value_type value_type;
  typedef typename Data::iterator iterator;
  typedef typename Data::const_iterator const_iterator;
  typedef typename Data::reference reference;
  typedef typename Data::const_reference const_reference;
  typedef typename Data::size_type size_type;
  typedef typename Data::difference_type difference_type;

  Tensor1() {}
  Tensor1(const Tensor1& tsr): data_(tsr.data_), strides_(tsr.strides_), ranges_(tsr.ranges_) {}
  ~Tensor1() {}
  Tensor1& operator=(const Tensor1& tsr) {
    data_ = tsr.data_;
    strides_ = tsr.strides_;
    ranges_ = tsr.ranges_;
    return *this;
  }

  // --- Flat, container-like access to the underlying storage ---------------
  bool empty() const {return data_.empty();}
  size_type size() const {return data_.size();}
  iterator begin() {return data_.begin();}
  const_iterator begin() const {return data_.begin();}
  iterator end() {return data_.end();}
  const_iterator end() const {return data_.end();}
  reference front() {return data_.front();}
  const_reference front() const {return data_.front();}
  reference back() {return data_.back();}
  const_reference back() const {return data_.back();}
  /// Assigns `t` to every stored element.
  void fill(const_reference t) {std::fill(data_.begin(), data_.end(), t);}
  /// Drops all elements together with the shape information.
  void clear() {data_.clear(); strides_.clear(); ranges_.clear();}

  // --- Shape queries ---------------------------------------------------------
  /// Number of dimensions.
  int rank() const {return ranges_.size();}
  /// Length of dimension `i_th_arg` (not range-checked here).
  int length(int i_th_arg) const {return ranges_[i_th_arg];}
  /// Copy of the lengths of all dimensions.
  Array<int> lengths() const {return ranges_;}

  /**
   * Sets the shape from the dimension lengths in [first, last) and resizes
   * the storage to the product of all positive lengths (1 if there is none).
   *
   * Calls Check() on the rank limit and Die() on a negative length.
   */
  template <typename InpIt>
  void set_size0(InpIt first, InpIt last) {
    ranges_.assign(first, last);
    // Compare as ints (as Tensor does) instead of mixing size_type and int.
    Check(rank() <= MAX_RANK);
    // strides_.resize(ranges_.size());
    strides_.fill(0);
    int n = 1;  // running product of the positive lengths seen so far
    for (int k = (ranges_.size() - 1); k >= 0; k--) {
      int l = ranges_[k];
      if (l > 0) {
        strides_[k] = n;
        n *= l;
      } else if (l == 0) {
        // no-op: the stride stays 0 and the element count is unaffected
      } else {
        Die("ranges[%d] = %d < 0", k, l);
      }
    }
    data_.resize(n);
  }
  // Fixed-arity conveniences for set_size0().  The rank-0 overload passes an
  // empty range over a one-element dummy buffer, because a zero-length array
  // (`int l[] = {}`) is not valid ISO C++.
  void set_size() {int l[1] = {0}; set_size0(l, l);}
  void set_size(int i0) {int l[] = {i0}; set_size0(l, l + 1);}
  void set_size(int i0, int i1) {int l[] = {i0, i1}; set_size0(l, l + 2);}
  void set_size(int i0, int i1, int i2) {int l[] = {i0, i1, i2}; set_size0(l, l + 3);}
  void set_size(int i0, int i1, int i2, int i3) {
    int l[] = {i0, i1, i2, i3}; set_size0(l, l + 4);}
  void set_size(int i0, int i1, int i2, int i3, int i4) {
    int l[] = {i0, i1, i2, i3, i4}; set_size0(l, l + 5);}

  // Element access by one index per dimension.  Each overload Assert()s that
  // the number of indices equals rank().  The rank-0 overloads hand a dummy
  // buffer to get_index(0, ...), which never reads it.
  reference ref() {
    Assert(rank()==0); int l[1]={0}; return data_[get_index(0, l)];}
  reference ref(int i0) {
    Assert(rank()==1); int l[]={i0}; return data_[get_index(1, l)];}
  reference ref(int i0, int i1) {
    Assert(rank()==2); int l[]={i0,i1}; return data_[get_index(2, l)];}
  reference ref(int i0, int i1, int i2) {
    Assert(rank()==3); int l[]={i0,i1,i2}; return data_[get_index(3, l)];}
  reference ref(int i0, int i1, int i2, int i3) {
    Assert(rank()==4); int l[]={i0,i1,i2,i3}; return data_[get_index(4, l)];}
  reference ref(int i0, int i1, int i2, int i3, int i4) {
    Assert(rank()==5); int l[]={i0,i1,i2,i3,i4}; return data_[get_index(5, l)];}
  const_reference ref() const {
    Assert(rank()==0); int l[1]={0}; return data_[get_index(0, l)];}
  const_reference ref(int i0) const {
    Assert(rank()==1); int l[]={i0}; return data_[get_index(1, l)];}
  const_reference ref(int i0, int i1)const  {
    Assert(rank()==2); int l[]={i0,i1}; return data_[get_index(2, l)];}
  const_reference ref(int i0, int i1, int i2) const {
    Assert(rank()==3); int l[]={i0,i1,i2}; return data_[get_index(3, l)];}
  const_reference ref(int i0, int i1, int i2, int i3) const {
    Assert(rank()==4); int l[]={i0,i1,i2,i3}; return data_[get_index(4, l)];}
  const_reference ref(int i0, int i1, int i2, int i3, int i4) const {
    Assert(rank()==5); int l[]={i0,i1,i2,i3,i4}; return data_[get_index(5, l)];}

  /// Element access with the indices given as the range [first, last).
  template <typename InpIt>
  reference ref0(InpIt first, InpIt last) {
    Assert(rank()==(int)std::distance(first, last)); 
    return data_[get_index((int)std::distance(first, last), first)];
  }
  template <typename InpIt>
  const_reference ref0(InpIt first, InpIt last) const {
    Assert(rank()==(int)std::distance(first, last)); 
    return data_[get_index((int)std::distance(first, last), first)];
  }
  /// Flat index of the element addressed by `idxs` (one index per dimension).
  int get_index(const Array<int>& idxs) const {
    Assert(rank() == (int)idxs.size());
    return get_index(idxs.size(), idxs.begin());
  }
  /**
   * Decomposes flat index `n` into one index per dimension.
   * Dimensions of length 0 are reported as -1.
   */
  Array<int> get_indexes(int n) const {
    Array<int> indexes(rank());
    for (int k = (rank() - 1); k >= 0; k--) {
      int l = ranges_[k];
      int r = (-1);
      if (l > 0) {
        r = n % l;
        n /= l;
      }
      indexes[k] = r;
    }
    return indexes;
  }
  /// One line per stored element, formatted as "[ i, j ]: value".
  std::string to_s() const {
    std::ostringstream oss;
    for (int i = 0; i < (int)data_.size(); i++) {
      const Array<int>& indexes = get_indexes(i);
      oss << "[ ";
      for (int k = 0; k < (int)indexes.size(); k++) {
        oss << indexes[k] 
            << (k == ((int)indexes.size() - 1) ? " " : ", ");
      }
      oss << "]: ";
      oss << data_[i] << "\n";
    }
    return oss.str();
  }
  /// Writes to_s() to std::cout and flushes.
  void print() const {std::cout << to_s() << std::flush;}
  
private:
  Data data_;                       // flat element storage
  //Array<int> strides_;
  CArray<int, MAX_RANK> strides_;   // flat-index step per dimension
  Array<int> ranges_;               // length of every dimension

  /**
   * Flat index for the `m` indices starting at `p`: the sum of
   * index * stride over the first `m` dimensions.
   *
   * Strides are read by position rather than through an
   * Array<int>::const_iterator, because strides_ is a CArray and its iterator
   * type need not match Array's.  When DEBUG is defined each index is also
   * Assert()ed to lie inside its dimension, unless that dimension has
   * length 0.
   */
  template <typename InpIt>
  int get_index(int m, InpIt p) const {
    int idx = 0;
    int k = 0;  // dimension whose index is consumed next
    while (m-- > 0) {
#ifdef DEBUG
      int i = (*p);
      Assert((ranges_[k] == 0) || (0 <= i && i < ranges_[k]),
             "ranges_[%d]=%d, strides_[%d]=%d, i=%d, rank=%d, m=%d, idx=%d", 
             k, ranges_[k], k, strides_[k], i, rank(), m, idx);
#endif
      idx += ((*p++) * strides_[k++]);
    }
    return idx;
  }
  // int get_index(int narg, ...) const {
  //   Assert((int)ranges_.size() == narg);
  //   std::va_list ap;
  //   std::va_start(ap, narg);
  //   int idx = 0;
  //   for (int k = 0; k < narg; k++) {
  //     int i = std::val_arg(ap, int);
  //     Assert((ranges_[k] == 0) || (0 <= i && i < ranges_[k]));
  //     idx += (i * strides_[k]);
  //   }
  //   std::va_end(ap);
  //   return idx;
  // }
};

// Abstract interface for tensors with element type T.
//
// Declares flat, container-like access (begin/end/front/back/size/fill),
// shape queries (rank, max_rank, length, lengths, strides), element access by
// up to four indices (ref) and conversion between flat and per-dimension
// indices (get_index / get_indexes).  Most members are pure virtual; the
// set_size() overloads, to_s() and print() have default implementations.
template <typename T>
class TensorBase {
public:
  // Container-style typedefs; the iterators are plain pointers to T.
  typedef T value_type;
  typedef T* iterator;
  typedef T const* const_iterator;
  typedef T& reference;
  typedef T const& const_reference;
  typedef std::size_t size_type;
  typedef int difference_type;

  virtual ~TensorBase() {}

  // Flat access to the stored elements.
  virtual bool empty() const = 0;
  virtual size_type size() const = 0;
  virtual iterator begin() = 0;
  virtual const_iterator begin() const = 0;
  virtual iterator end() = 0;
  virtual const_iterator end() const = 0;
  virtual reference front() = 0;
  virtual const_reference front() const = 0;
  virtual reference back() = 0;
  virtual const_reference back() const = 0;
  virtual void fill(const_reference t) = 0;
  virtual void clear() = 0;
  // Shape queries: number of dimensions, its upper bound, the length of
  // dimension i, and the per-dimension lengths and strides.
  virtual size_type rank() const = 0;
  virtual size_type  max_rank() const = 0;
  virtual size_type length(int i) const = 0;
  virtual Array<int> lengths() const = 0;
  virtual Array<int> strides() const = 0;
  // A range-based set_size0() cannot be part of this interface because a
  // member template may not be virtual:
  // template <typename InpIt>
  //virtual void set_size0(InpIt first, InpIt last) {}
  // Resizing hooks taking one length per dimension; the defaults do nothing.
  virtual void set_size() {}
  virtual void set_size(int i0) {}
  virtual void set_size(int i0, int i1) {}
  virtual void set_size(int i0, int i1, int i2) {}
  virtual void set_size(int i0, int i1, int i2, int i3) {}
  virtual void set_size(int i0, int i1, int i2, int i3, int i4) {}
  // Element access by one index per dimension.  Only zero to four indices
  // are declared here, although set_size() also has a five-length overload.
  virtual reference ref() = 0;
  virtual reference ref(int i0) = 0;
  virtual reference ref(int i0, int i1) = 0;
  virtual reference ref(int i0, int i1, int i2) = 0;
  virtual reference ref(int i0, int i1, int i2, int i3) = 0;
  virtual const_reference ref() const = 0;
  virtual const_reference ref(int i0) const = 0;
  virtual const_reference ref(int i0, int i1) const = 0;
  virtual const_reference ref(int i0, int i1, int i2) const = 0;
  virtual const_reference ref(int i0, int i1, int i2, int i3) const = 0;

  // Conversion between per-dimension indices and a flat index.
  virtual int get_index(const Array<int>& idxs) const = 0;
  virtual Array<int> get_indexes(int n) const = 0;
  // Textual dump; the default is an empty string.
  virtual std::string to_s() const {return "";}
  // Writes to_s() to std::cout and flushes.
  virtual void print() const {std::cout << to_s() << std::flush;}
};


/**
 * Tensor whose shape is fixed at compile time by the lengths L0..L3.
 *
 * A length of -1 (the default) marks a dimension as unused.  RANK counts the
 * lengths that are not -1, while lengths() and strides() report the first
 * RANK entries, so the used dimensions are expected to be the leading ones.
 * A length of 0 gives the dimension stride 0 and counts as 1 in SIZE.
 *
 * Elements are stored in `data_`, a CArray<T, SIZE>.  Strides are products of
 * the lengths of the following dimensions, so the last index varies fastest.
 * The set_size() overloads are deliberate no-ops.
 *
 * @tparam T  element type
 * @tparam L0 length of dimension 0, or -1 if unused
 * @tparam L1 length of dimension 1, or -1 if unused
 * @tparam L2 length of dimension 2, or -1 if unused
 * @tparam L3 length of dimension 3, or -1 if unused
 */
template <typename T, int L0 = (-1), int L1 = (-1), int L2 = (-1), int L3 = (-1)>
class CTensor : public TensorBase<T> {
  typedef TensorBase<T> Super;
public:
  // L_P_(k): 1 if dimension k is in use; LEN_(k): its length for size and
  // stride products, with non-positive lengths counted as 1.
#define L_P_(k) (L##k != (-1))
#define LEN_(k) (L##k > 0 ? L##k : 1)
  enum {
    MAX_RANK = 4,
    RANK     = (L_P_(0) + L_P_(1) + L_P_(2) + L_P_(3)),
    SIZE     = (LEN_(0) * LEN_(1) * LEN_(2) * LEN_(3) * 1),
    STRIDE0  = (L0<=0 ? 0 : (LEN_(1) * LEN_(2) * LEN_(3) * 1)),
    STRIDE1  = (L1<=0 ? 0 : (LEN_(2) * LEN_(3) * 1)),
    STRIDE2  = (L2<=0 ? 0 : (LEN_(3) * 1)),
    STRIDE3  = (L3<=0 ? 0 : (1))
  };
#undef L_P_
#undef LEN_
  
  typedef typename Super::value_type value_type;
  typedef typename Super::iterator iterator;
  typedef typename Super::const_iterator const_iterator;
  typedef typename Super::reference reference;
  typedef typename Super::const_reference const_reference;
  typedef typename Super::size_type size_type;
  typedef typename Super::difference_type difference_type;

  // The fixed-size CArray storage is the active configuration; the #else
  // branch is an alternative that keeps the elements in an Array<T>.
#if true
  typedef CArray<T, SIZE> Data;
#else
  typedef Array<T> Data;
  CTensor() : data_(SIZE) {}
  virtual ~CTensor() {}
#endif

  // --- Flat, container-like access to the underlying storage ---------------
  bool empty() const {return data_.empty();}
  size_type size() const {return data_.size();}
  iterator begin() {return data_.begin();}
  const_iterator begin() const {return data_.begin();}
  iterator end() {return data_.end();}
  const_iterator end() const {return data_.end();}
  reference front() {return data_.front();}
  const_reference front() const {return data_.front();}
  reference back() {return data_.back();}
  const_reference back() const {return data_.back();}
  void fill(const_reference t) {data_.fill(t);}
  void clear() {data_.clear();}

  // --- Shape queries ---------------------------------------------------------
  size_type rank() const {return RANK;}
  size_type max_rank() const {return MAX_RANK;}
  /// Template length of dimension `i`, converted to size_type (`i` is not
  /// range-checked here).
  size_type length(int i) const {
    const CArray<int, MAX_RANK> arr0 = {{L0, L1, L2, L3}};
    return arr0[i];
  }
  /// Lengths of the first rank() dimensions.
  Array<int> lengths() const {
    const CArray<int, MAX_RANK> arr0 = {{L0, L1, L2, L3}};
    return Array<int>(arr0.begin(), arr0.begin() + rank());
  }
  /// Strides of the first rank() dimensions.
  Array<int> strides() const {
    const CArray<int, MAX_RANK> arr0 = {{STRIDE0, STRIDE1, STRIDE2, STRIDE3}};
    return Array<int>(arr0.begin(), arr0.begin() + rank());
  }
  // The shape is fixed by the template arguments, so resizing does nothing.
  // template <typename InpIt>
  // void set_size0(InpIt first, InpIt last) {}
  void set_size() {}
  void set_size(int i0) {}
  void set_size(int i0, int i1) {}
  void set_size(int i0, int i1, int i2) {}
  void set_size(int i0, int i1, int i2, int i3) {}
  void set_size(int i0, int i1, int i2, int i3, int i4) {}

  // Element access by one index per dimension.  GET_INDEX_n expands to the
  // sum of STRIDEk * ik over the given indices; CHECK_RANK calls Die() when
  // the number of indices differs from RANK.
#define PROD(k) (STRIDE##k * i##k)
#define GET_INDEX_0()               (0)
#define GET_INDEX_1(i0)             (PROD(0))
#define GET_INDEX_2(i0, i1)         (PROD(0)+PROD(1))
#define GET_INDEX_3(i0, i1, i2)     (PROD(0)+PROD(1)+PROD(2))
#define GET_INDEX_4(i0, i1, i2, i3) (PROD(0)+PROD(1)+PROD(2)+PROD(3))
#define CHECK_RANK(n) do {if ((n) != RANK) Die("bad rank %d != %d", (n), RANK);} while (0)
  reference ref() {
    CHECK_RANK(0); return data_[GET_INDEX_0()];}
  reference ref(int i0) {
    CHECK_RANK(1); return data_[GET_INDEX_1(i0)];}
  reference ref(int i0, int i1) {
    CHECK_RANK(2); return data_[GET_INDEX_2(i0, i1)];}
  reference ref(int i0, int i1, int i2) {
    CHECK_RANK(3); return data_[GET_INDEX_3(i0, i1, i2)];}
  reference ref(int i0, int i1, int i2, int i3) {
    CHECK_RANK(4); return data_[GET_INDEX_4(i0, i1, i2, i3)];}
  const_reference ref() const {
    CHECK_RANK(0); return data_[GET_INDEX_0()];}
  const_reference ref(int i0) const {
    CHECK_RANK(1); return data_[GET_INDEX_1(i0)];}
  const_reference ref(int i0, int i1)const  {
    CHECK_RANK(2); return data_[GET_INDEX_2(i0, i1)];}
  const_reference ref(int i0, int i1, int i2) const {
    CHECK_RANK(3); return data_[GET_INDEX_3(i0, i1, i2)];}
  const_reference ref(int i0, int i1, int i2, int i3) const {
    CHECK_RANK(4); return data_[GET_INDEX_4(i0, i1, i2, i3)];}
#undef CHECK_RANK
#undef PROD
#undef GET_INDEX_0
#undef GET_INDEX_1
#undef GET_INDEX_2
#undef GET_INDEX_3
#undef GET_INDEX_4

  /// Flat index of the element addressed by `idxs` (one index per dimension).
  int get_index(const Array<int>& idxs) const {
    Assert(rank() == idxs.size());
    const Array<int>& arr = strides();
    return inner_product(arr.begin(), arr.end(), idxs.begin(), 0);
  }
  /**
   * Decomposes flat index `n` into one index per dimension.
   * Dimensions of length 0 are reported as -1.
   */
  Array<int> get_indexes(int n) const {
    const Array<int>& ranges = lengths();
    Array<int> indexes(rank());
    // Start from RANK as a signed value: rank() is unsigned, so rank() - 1
    // would wrap around for a rank-0 tensor.
    for (int k = (static_cast<int>(RANK) - 1); k >= 0; k--) {
      int l = ranges[k];
      int r = (-1);
      if (l > 0) {
        r = n % l;
        n /= l;
      }
      indexes[k] = r;
    }
    return indexes;
  }
  /// One line per stored element, formatted as "[ i, j ]: value".
  std::string to_s() const {
    std::ostringstream oss;
    for (int i = 0; i < (int)data_.size(); i++) {
      const Array<int>& indexes = get_indexes(i);
      oss << "[ ";
      for (int k = 0; k < (int)indexes.size(); k++) {
        oss << indexes[k] 
            << (k == ((int)indexes.size() - 1) ? " " : ", ");
      }
      oss << "]: ";
      oss << data_[i] << "\n";
    }
    return oss.str();
  }
  /// Writes to_s() to std::cout and flushes.
  void print() const {std::cout << to_s() << std::flush;}
  
  /// Element storage (publicly accessible).
  Data data_;
};
}
#endif
