⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 tensor.h

📁 Tensor类
💻 H
字号:
/* Tensor.H                                 Version from  05/02/97 17:07    */
/* Contents ----------------------------------------------------------------**
**									                                                        **
**  Class 	Tensor                                                          **
**																		                                      **
**--------------------------------------------------------------------------**
**                                                                          **
** COPYRIGHT (C) 1997 by Melnikov Mike. All rights reserved.                **
** For any comments or suggestions mailto:zmike@andnow.ru                   **
**                                                                          **
** -------------------------------------------------------------------------*/

#ifndef _TENSOR_
#define _TENSOR_

#include <assert.h>
#include <cctype>
#include <cstddef>
#include <cstring>
#include <ostream>
#include <string>

//-------------------------------------------------------------------------------------
// ----- TensorFormatString                  ------------------------------------------
//-------------------------------------------------------------------------------------

// TensorFormatString: a fixed-capacity index string (at most 9 characters plus
// the terminating NUL). Each character encodes one index as its offset from
// m_beginSymbol, so with the default base '1' the string "12" addresses
// indices 0 and 1.
class TensorFormatString
{
  public:
    // Empty format string with index base '1'. The whole buffer is zero-filled
    // so it never holds indeterminate bytes, even after setValue() writes past
    // the current end of the string.
    TensorFormatString() : m_format(), m_beginSymbol('1') {}

    // Copy constructor. The copy-assignment overload takes a non-const
    // reference (legacy signature, defined out of line), so a plain
    // operator=(str) on a const argument cannot select it. It would instead
    // go through the implicit conversion operators, and that is ambiguous on
    // conforming compilers. The const_cast routes the call to the copy overload.
    // NOTE(review): assumes operator=(TensorFormatString&) does not modify its
    // argument - confirm against its definition.
    TensorFormatString( const TensorFormatString&  str ) : m_format(), m_beginSymbol('1')
    { operator=( const_cast<TensorFormatString&>(str) ); }

    TensorFormatString( const char*  str ) : m_format(), m_beginSymbol('1') { operator=(str); }

    virtual ~TensorFormatString()  {}

    // Implicit view of the raw format characters.
    operator const char*()  const                               { return getFormat(); }

    // Assignment overloads; defined out of line (not visible in this header).
    TensorFormatString&     operator =( TensorFormatString& str);
    TensorFormatString&     operator =( const char* str );
    TensorFormatString&     operator =( const int  value );

    const char*             getFormat() const { return m_format; }

    // First format character (NUL for an empty string).
    operator char() const   { return m_format[0]; }

    // Index encoded at position `index`: the character's offset from the base symbol.
    int operator []         ( int index ) const
    {
        assert( index >= 0 && index < static_cast<int>(sizeof(m_format)) );
        return m_format[index] - m_beginSymbol;
    }

    // True when the string is non-empty.
    operator bool() const   { return m_format[0] != '\0'; }

    int                     search(const char  symbol);   // defined out of line

    // Stores `value` at position `index`, encoded relative to the base symbol.
    void                    setValue( int index, int value )
    {
        // The last byte is reserved for the terminating NUL.
        assert( index >= 0 && index < static_cast<int>(sizeof(m_format)) - 1 );
        m_format[index] = static_cast<char>(m_beginSymbol + value);
    }

    // Index base expressed as a digit (1 by default).
    int                     getm_beginSymbol() const    { return m_beginSymbol - '0'; }
    void                    setm_beginSymbol(int begin) { m_beginSymbol = static_cast<char>(begin + '0'); }

    // Non-zero when the character at `index` is a decimal digit.
    int                     isdigit(int index) const
    {
        assert( index >= 0 && index < static_cast<int>(sizeof(m_format)) );
        // <cctype> functions require a value representable as unsigned char.
        return std::isdigit( static_cast<unsigned char>(m_format[index]) );
    }

  private:

    char                    m_format[10];
    char                    m_beginSymbol;
};

//-------------------------------------------------------------------------------------
// ----- TensorObject  - public for Tensors--------------------------------------------
//-------------------------------------------------------------------------------------
// Abstract base shared by every tensor level (Tensor<...> and TensorElement<...>).
class  TensorObject
{
public:
    // m_level starts at 0; derived constructors set the real nesting depth.
    TensorObject() : m_level(0) {}

    // Virtual so that destroying a derived tensor through TensorObject* is well defined.
    virtual ~TensorObject() {}

    // Index by a format string such as "12" (see TensorFormatString).
    virtual TensorObject& operator [] ( const char *str ) = 0 ;
    virtual TensorObject& operator [] ( TensorFormatString& str ) = 0;

    virtual const TensorFormatString& getFormat() { return m_format;}

    // Nesting depth: 0 for a scalar TensorElement, one more per enclosing Tensor level.
    int                               getLevel() const { return m_level;}

    // Writes the object to `op`; Tensor::printf prints a non-NULL `str` first.
    virtual void                      printf(std::ostream &op,char* str=NULL) = 0;

    // Accumulation hook: TensorElement overrides it to add o1 * o2 to its value.
    // The base version asserts, so it must never be reached for non-scalars.
    // NOTE(review): presumably driven by convolution() (defined elsewhere) - confirm.
    virtual void                      internal_add(TensorObject& o1,TensorObject& o2) {assert(false);}

protected:
    int                               m_level;
    TensorFormatString                m_format;
};

//-------------------------------------------------------------------------------------
// ----- Template class Tensor        -------------------------------------------------
//-------------------------------------------------------------------------------------
// Tensor<Element>: a fixed-size array of m_dim Elements. Nesting Tensor inside
// Tensor (see the Tensor1..Tensor4 typedefs) builds higher-order tensors.
template <class Element>
class Tensor : public TensorObject{

  public :
      // Creates a tensor of dimension n; nested Elements are given dimension n too.
      Tensor(int n=2);
      Tensor(const Tensor<Element>& t);
      ~Tensor();

      Tensor<Element>& operator = ( const Tensor<Element>& );

      // Direct access to the index-th element (bounds checked by assert).
      const Element& operator [] ( int index ) const { assert(index >= 0 && index < m_dim); return m_items[index]; }
      Element&       operator [] ( int index )       { assert(index >= 0 && index < m_dim); return m_items[index]; }

      // Format-string indexing. If the first character is a digit, return the
      // element it selects, indexed further by the rest of the string. Otherwise
      // store the whole string in m_format and return *this.
      virtual TensorObject& operator [] ( const char *str )
      {
          // A named object is required: the overload below takes a non-const
          // reference, which cannot bind to a temporary TensorFormatString.
          TensorFormatString format(str);
          return operator[](format);
      }
      virtual TensorObject& operator [] ( TensorFormatString& str );


      Tensor<Element>& operator~(); // clear: applies operator~ to every element

      Tensor<Element>  operator+ ( const Tensor<Element>& t);
      Tensor<Element>  operator- ( const Tensor<Element>& t);
      Tensor<Element>& operator-=( const Tensor<Element>& t);
      Tensor<Element>& operator+=( const Tensor<Element>& t);
      Tensor<Element>& operator*=( double num);
      Tensor<Element>  operator* ( double num);

      int                       getDim() const { return m_dim;}

      // Overwrites *this with the sum of ten[i] * r[i] for i in [0, count).
      Tensor<Element>&          averaging(Tensor<Element>* ten,double *r,int count );

      virtual void              printf(std::ostream &op,char* str=NULL);

  public:
      // Allocates n elements and initialises each nested Element with dimension n.
      void                      init(int n);

  private:

      Element                   *m_items;
      int                       m_dim;
};

//-------------------------------------------------------------------------------------
// ----- TensorElement - T m_value -------------------------------------------------
//-------------------------------------------------------------------------------------
// Default scalar type for TensorElement.
// NOTE(review): names starting with an underscore followed by an uppercase
// letter are reserved for the implementation; kept for source compatibility.
typedef double _Double;

// TensorElement<T>: a scalar leaf of a tensor (level 0) wrapping one value of type T.
template<typename T = _Double>
class TensorElement : public TensorObject
{
public:
    TensorElement()  : m_value(0) { m_level=0;}
    TensorElement(T val)  : m_value(val) { m_level=0;}


    operator T() const                                     { return m_value;}

    virtual TensorElement& operator = ( const T& val)      { m_value = val; return *this;}

    TensorElement& operator = ( const TensorElement& ten)  { m_value = T(ten);return *this;}

    // A scalar has no sub-elements: any index string resolves to the element itself.
    virtual TensorObject&  operator [] ( const char *str ) {return *this;}
    virtual TensorObject&  operator [] ( TensorFormatString& str ) {return *this; }

    // Resets the value to T(); returns the element (same shape as Tensor::operator~).
    TensorElement&  operator~() { m_value = T(); return *this;}

    // Prints the value followed by a space; `str` is ignored.
    // The default argument matches TensorObject::printf.
    virtual void    printf(std::ostream &op,char* str=NULL)      { op << m_value << ' '; }

    // Adds the product of two scalar elements to this one: value += o1 * o2.
    // Precondition: both arguments really are TensorElement<T> objects.
    virtual void    internal_add(TensorObject& o1,TensorObject& o2)
    {
        const TensorElement<T>& a = static_cast<const TensorElement<T>&>(o1);
        const TensorElement<T>& b = static_cast<const TensorElement<T>&>(o2);
        m_value = m_value + T(a) * T(b);
    }

    // Compound assignment with a scalar. -= and *= give
    // Tensor<TensorElement<T> >::operator-= / operator*= something to call.
    TensorElement&  operator+=( const T& t) {m_value += t; return *this;}
    TensorElement&  operator-=( const T& t) {m_value -= t; return *this;}
    TensorElement&  operator*=( const T& t) {m_value *= t; return *this;}

public:
    // Scalars have no storage to size; present so Tensor::init can recurse uniformly.
    void                          init(int n) {}

private:
    T                             m_value;
};




// Ready-made tensor types: Tensor1 holds scalar TensorElement<> values
// (double by default), and each TensorN nests TensorN-1 one level deeper.
typedef Tensor<TensorElement<> >  Tensor1;
typedef Tensor<Tensor1>        Tensor2;
typedef Tensor<Tensor2>        Tensor3;
typedef Tensor<Tensor3>        Tensor4;

// Builds a tensor of dimension n. m_items starts out NULL so init() can
// safely release whatever it holds.
template <class Element>
Tensor<Element>::Tensor(int n)
    : m_items(NULL), m_dim(0)
{
    init(n);
}

// (Re)allocates storage for n elements and gives every nested Element
// dimension n as well. Storage held before the call is released. This matters
// for nested Tensors, which are first default-constructed with dimension 2 and
// then re-initialised here; without the release that first array would leak.
template <class Element>
void Tensor<Element>::init(int n)
{
    assert(n > 0);   // m_level is derived from element 0, which must exist

    // Allocate before releasing so *this stays valid if new throws.
    Element* items = new Element[n];
    delete [] m_items;
    m_items = items;
    m_dim   = n;

    m_level = m_items[0].getLevel() + 1;
    for( int i=0; i<n; i++)
        m_items[i].init(n);
}

// Copy constructor: sizes the new tensor to t's dimension, then copies t
// element by element. Previously a fixed dimension of 2 was allocated, so
// copying a larger tensor wrote past the end of m_items.
template <class Element>
Tensor<Element>::Tensor(const Tensor<Element>& t)
    : m_items(NULL), m_dim(0)
{
    init(t.m_dim);
    operator=(t);
}

// Releases the element array. delete[] runs each Element's destructor, so
// nested Tensors free their own storage in turn.
template <class Element>
Tensor<Element>::~Tensor()
{
    Element* const items = m_items;
    delete [] items;
}

// Clears the tensor by applying operator~ to every element (recursing down
// to the scalar TensorElements). Returns *this.
template <class Element>
Tensor<Element>& Tensor<Element>::operator~()
{
    const int count = getDim();
    for( Element* p = m_items; p != m_items + count; ++p )
        ~(*p);
    return *this;
}

// Copy assignment. If the dimensions differ, the element array is first
// reallocated to t's dimension, so the copy never writes past the end of
// m_items. Nested Elements are assigned recursively; m_format is not copied.
// Self-assignment is safe: the dimensions match, so elements are copied onto themselves.
template <class Element>
Tensor<Element>& Tensor<Element>::operator = ( const Tensor<Element>& t )
{
    if( m_dim != t.m_dim )
    {
        // Allocate before releasing so *this stays valid if new throws.
        Element* items = new Element[t.m_dim];
        delete [] m_items;
        m_items = items;
        m_dim   = t.m_dim;
    }
    for( int i=0; i < m_dim; i++ )
        m_items[i] = t[i];
    return *this;
}

// Format-string indexing.
// - If the first character of str is a digit, it is decoded (relative to str's
//   base symbol) into an element index. The selected element is then indexed
//   with the rest of the string, which keeps the same base symbol.
// - Otherwise str is stored in m_format and *this is returned.
template <class Element>
TensorObject& Tensor<Element>::operator [] ( TensorFormatString& str )
{
    if( !str.isdigit(0) )
    {
        m_format = str;
        return *this;
    }

    const int iItem = str[0];
    assert( iItem >= 0 && iItem < m_dim );

    // Constructing from the raw characters alone resets the base symbol to
    // the default '1', so carry the caller's base over explicitly.
    TensorFormatString rest( str.getFormat() + 1 );
    rest.setm_beginSymbol( str.getm_beginSymbol() );
    return m_items[iItem][rest];
}

// Element-wise sum. The result has this tensor's dimension (not the default
// dimension 2, which overflowed for larger tensors); t must match it.
template <class Element>
Tensor<Element> Tensor<Element>::operator+ ( const Tensor<Element>& t)
{
    assert( t.getDim() == getDim() );
    Tensor<Element> tmp( getDim() );
    for(int i=0; i < getDim(); i++)
        tmp[i] = m_items[i] + t[i];
    return tmp;
}

// Element-wise difference. The result has this tensor's dimension (not the
// default dimension 2, which overflowed for larger tensors); t must match it.
template <class Element>
Tensor<Element> Tensor<Element>::operator- ( const Tensor<Element>& t)
{
    assert( t.getDim() == getDim() );
    Tensor<Element> tmp( getDim() );
    for(int i=0; i < getDim(); i++)
        tmp[i] = m_items[i] - t[i];
    return tmp;
}

// In-place element-wise subtraction. Iterates over this tensor's own elements;
// t must have the same dimension. Looping over t's dimension used to write
// past m_items when t was larger.
template <class Element>
Tensor<Element>& Tensor<Element>::operator-= ( const Tensor<Element>& t)
{
    assert( t.getDim() == getDim() );
    for(int i=0; i < getDim(); i++)
        m_items[i] -= t[i];
    return *this;
}

// In-place element-wise addition. Iterates over this tensor's own elements;
// t must have the same dimension. Looping over t's dimension used to write
// past m_items when t was larger.
template <class Element>
Tensor<Element>& Tensor<Element>::operator+= ( const Tensor<Element>& t)
{
    assert( t.getDim() == getDim() );
    for(int i=0; i < getDim(); i++)
        m_items[i] += t[i];
    return *this;
}

// Scales every element by num in place. Returns *this.
template <class Element>
Tensor<Element>& Tensor<Element>::operator*= ( double num)
{
    for( Element *p = m_items, *end = m_items + getDim(); p != end; ++p )
        *p *= num;
    return *this;
}

// Returns a copy of this tensor scaled by num. The result has this tensor's
// dimension (not the default dimension 2, which overflowed for larger tensors).
template <class Element>
Tensor<Element> Tensor<Element>::operator* (double num)
{
    Tensor<Element> tmp( getDim() );
    for(int i=0; i < getDim(); i++)
        tmp[i] = m_items[i] * num;
    return tmp;
}

// Declared here; the definition is not in this header.
// NOTE(review): presumably returns the inverse of the rank-4 tensor t - confirm
// against the definition.
Tensor4 inverse(Tensor4& t);

// Prints the tensor to op in three steps:
// 1. the optional header str;
// 2. each element, selected through operator[] with a format string assigned
//    from its index i;
// 3. a trailing newline.
template <class Element>
void Tensor<Element>::printf(std::ostream &op,char* str)
{
    if( str != NULL )
        op << str;

    // Scratch format string, reassigned to each element index below.
    TensorFormatString indexFormat(getFormat());

    for( int i = 0; i < getDim(); ++i )
    {
        indexFormat = i;
        TensorObject& item = operator[](indexFormat);
        item.printf(op);
    }
    op << '\n';
}

// Weighted sum: clears *this, then accumulates ten[k] * r[k] for k in
// [0, count). Returns *this.
template <class Element>
Tensor<Element>& Tensor<Element>::averaging(Tensor<Element>* ten,double *r,int count )
{
    operator~();
    for( int k = 0; k < count; ++k )
    {
        const Tensor<Element>& weighted = ten[k] * r[k];
        operator+=(weighted);
    }
    return *this;
}

// The main reason I wrote this code.
// Declared here; the definition is not in this header.
// NOTE(review): presumably contracts tenA with tenB into result - confirm
// against the definition.
void convolution( TensorObject& result, TensorObject& tenA, TensorObject& tenB );

#endif

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -