⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 neuronview.cpp

📁 神经网络中的多层感知器的BP(反向传播)学习算法
💻 CPP
字号:
// NeuronView.cpp : implementation file
//

#include "stdafx.h"

#include <float.h>
#include <limits.h>

#include "xor_learn.h"
#include "NeuronView.h"

#ifdef _DEBUG
#define new DEBUG_NEW
#undef THIS_FILE
static char THIS_FILE[] = __FILE__;
#endif

/* Pixels per model-space unit: MPtoLP/LPtoMP scale coordinates by it and
   DrawCoordinate spaces the axis ticks BASE_UNIT pixels apart. */
#define BASE_UNIT 50.0
/* Half-size in pixels of the ellipse OnPaint draws for each sample point. */
#define POINT_UNIT 3.0

/* Identifier of the training timer (stored in m_uiTimerID, used by SetTimer/KillTimer). */
#define TIMER_ID 0xf0f0f0f0
/////////////////////////////////////////////////////////////////////////////
// CNeuronView

/*
 * Construct the view with no network, no listeners and training stopped,
 * then load the initial sample set from "./pattern/pattern.patt" (resolved
 * against the process working directory).
 *
 * File format: an integer count N followed by N records "x y class".
 * A missing file, a missing/non-positive count, or a truncated/corrupt file
 * is tolerated: only the records that were read completely are kept.
 */
CNeuronView::CNeuronView()
{
    m_bOver = FALSE;
    m_hCursor = AfxGetApp()->LoadStandardCursor(IDC_CROSS);
    m_pWeightsListener = NULL;
    m_pMouseListener = NULL;
    m_pNeuralNetWork = NULL;
    m_uiTimerID = TIMER_ID;
    m_uiTimeout = 1;
    m_iCurrentPoint = -1;

    ifstream fin("./pattern/pattern.patt");
    if (!fin.is_open())
        return;

    int iSize = 0;
    if (!(fin >> iSize) || iSize <= 0)
        return;     /* unreadable or non-positive count: start with no samples */

    /* Read record by record instead of pre-sizing the vectors, so a bogus
       count cannot trigger a huge allocation and a short file leaves no
       half-initialised entries behind. */
    CPointf p;
    int iClass = 0;
    for (int i = 0; i < iSize; i++)
    {
        if (!(fin >> p.x >> p.y >> iClass))
            break;
        m_vecPoint.push_back(p);
        m_vecPointClass.push_back(iClass);
    }
}

/*
 * Release the network created by ResetNeuron (allocated with new and owned
 * by this view; ResetNeuron itself deletes the previous one). Deleting NULL
 * is a no-op, so a view that was never reset is fine.
 */
CNeuronView::~CNeuronView()
{
    delete m_pNeuralNetWork;
    m_pNeuralNetWork = NULL;
}


/*
 * Message routing for CNeuronView:
 *  - WM_PAINT                      -> OnPaint (sample points + decision lines)
 *  - WM_MOUSELEAVE / WM_MOUSEHOVER -> OnMouseLeave / OnMouseHover, raw
 *    ON_MESSAGE handlers; they are only delivered because OnMouseMove asks
 *    for them with _TrackMouseEvent(TME_LEAVE | TME_HOVER)
 *  - WM_MOUSEMOVE, WM_SETCURSOR    -> hover tracking and the cross cursor
 *  - WM_TIMER                      -> OnTimer (one training step per tick)
 *  - WM_LBUTTONDOWN                -> OnLButtonDown
 * The //{{AFX_MSG_MAP markers are ClassWizard bookkeeping; keep them intact.
 */
BEGIN_MESSAGE_MAP(CNeuronView, CStatic)
	//{{AFX_MSG_MAP(CNeuronView)
	ON_WM_PAINT()
    ON_MESSAGE(WM_MOUSELEAVE, OnMouseLeave)
	ON_MESSAGE(WM_MOUSEHOVER, OnMouseHover)
	ON_WM_MOUSEMOVE()
	ON_WM_SETCURSOR()
	ON_WM_TIMER()
	ON_WM_LBUTTONDOWN()
	//}}AFX_MSG_MAP
END_MESSAGE_MAP()

/////////////////////////////////////////////////////////////////////////////
// CNeuronView message handlers
void CNeuronView::OnPaint() 
{
    static CPen penPoint(PS_SOLID, 1, RGB(74, 88, 41));
    static CPen penLine1(PS_SOLID, 1, RGB(0, 0, 255));
    static CPen penLine2(PS_SOLID, 1, RGB(174, 88, 41));
    static CPen penLine3(PS_SOLID, 1, RGB(74, 188, 41));
    static CBrush brushBg(RGB(255, 255, 255));
    static CBrush burshCurrentPoint(RGB(0, 255, 0));
    static CBrush brushPoint[7] = 
    {
        RGB(0, 255, 255),
        RGB(0, 128, 128),
        RGB(0, 255, 0),
        RGB(255, 255, 0),
        RGB(255, 0, 0),
        RGB(128, 128, 0),
        RGB(128, 0, 0)    
    };
    
    CBitmap memBitMap;
    CBitmap *pOldBit;
    CBrush *pOldBrush;
    CPen *pOldPen;
    CRect rectSrc, rectDst, rect1, rect2;
    int cx, cy;
    CDC *pDC;
    CDC memDC;
    
    pDC = GetDC();
    GetClientRect(&rectDst);
    memDC.CreateCompatibleDC(pDC);
    memBitMap.CreateCompatibleBitmap(pDC, rectDst.Width(), rectDst.Height());

    pOldPen = memDC.GetCurrentPen();
    pOldBrush = memDC.GetCurrentBrush();
    pOldBit = memDC.SelectObject(&memBitMap);

    /* move the rect to (0, 0) */
    rectSrc.CopyRect(&rectDst);
    rectSrc.OffsetRect(-rectDst.left, -rectDst.top);


    rect1.CopyRect(&rectSrc);
    rect2.CopyRect(&rectSrc);

    cx = (rectSrc.Width()) / 4;
    cy = (rectSrc.Height()) / 2;

    rect1.right = 2 * cx;
    rect2.left = 2 * cx;

    /* draw background */
    memDC.FillRect(&rectSrc, &brushBg);
    
    /* draw box */
    memDC.Rectangle(&rect1);
    memDC.Rectangle(&rect2);
    
    DrawCoordinate(memDC, rect1);
    DrawCoordinate(memDC, rect2);

    /* draw points */
    memDC.SelectObject(&penPoint);
    CPointf p;
    CPoint pL;
    for (int i = 0; i < m_vecPoint.size(); i++)
    {
        CPoint pL;
        p = m_vecPoint[i];
        if (i == m_iCurrentPoint)
            memDC.SelectObject(&burshCurrentPoint);
        else
            memDC.SelectObject(&brushPoint[m_vecPointClass[i] + 3]);
        MPtoLP(rect1, &p, &pL);
        if (rect1.PtInRect(pL))
            memDC.Ellipse(pL.x - POINT_UNIT, pL.y - POINT_UNIT, 
                          pL.x + POINT_UNIT, pL.y + POINT_UNIT);
    }    
    
    /* draw line */
    double w1, w2, b;
    if (m_pNeuralNetWork)
    {
        memDC.SelectObject(&penLine1);
        w1 = m_pNeuralNetWork->getWeight(1, 0, 0);
        w2 = m_pNeuralNetWork->getWeight(1, 1, 0);
        b = m_pNeuralNetWork->getB(1, 0);
        
        DrawLine(memDC, rect1, w1, w2, b);
        
        memDC.SelectObject(&penLine2);
        w1 = m_pNeuralNetWork->getWeight(1, 0, 1);
        w2 = m_pNeuralNetWork->getWeight(1, 1, 1);
        b = m_pNeuralNetWork->getB(1, 1);
        
        DrawLine(memDC, rect1, w1, w2, b);

        memDC.SelectObject(&penLine3);
        w1 = m_pNeuralNetWork->getWeight(2, 0, 0);
        w2 = m_pNeuralNetWork->getWeight(2, 1, 0);
        b = m_pNeuralNetWork->getB(2, 0);
        
        DrawLine(memDC, rect2, w1, w2, b);
    }

    pDC->BitBlt(rectDst.left, rectDst.top, rectDst.Width(), rectDst.Height(),
                &memDC, rectSrc.left, rectSrc.top, SRCCOPY);

    memDC.SelectObject(pOldPen);
    memDC.SelectObject(pOldBrush);
    memDC.SelectObject(pOldBit);
    memDC.DeleteDC();
    memBitMap.DeleteObject();
    ReleaseDC(pDC);
    CStatic::OnPaint();
	// Do not call CStatic::OnPaint() for painting messages
}

/* Convert a pixel coordinate to long by truncation, saturating at
   LONG_MIN/LONG_MAX instead of invoking undefined behaviour when the value
   is out of range; NaN maps to 0. */
static long SaturateToLong(double v)
{
    if (v != v)
        return 0;
    if (v >= (double)LONG_MAX)
        return LONG_MAX;
    if (v <= (double)LONG_MIN)
        return LONG_MIN;
    return (long)v;
}

/*
 * Map a model-space point to a pixel inside rect: the origin is the centre
 * of rect, one model unit spans BASE_UNIT pixels and model y grows upwards.
 * In-range results are identical to a plain (long) cast; points far outside
 * the view (e.g. from the pattern file) are clamped rather than overflowing.
 */
void CNeuronView::MPtoLP(CRect& rect, CPointf *pointf, CPoint *point)
{
    double ax, bx, ay, by;

    /* centre of rect (double midpoint truncated to int) */
    int cx = (int)(rect.left + (rect.right - rect.left) / 2.0);
    int cy = (int)(rect.top + (rect.bottom - rect.top) / 2.0);

    /* model = a * pixel + b, solved for pixel below */
    ax = 1.0 / BASE_UNIT;
    bx = -cx / BASE_UNIT;
    ay = -1.0 / BASE_UNIT;
    by = cy / BASE_UNIT;

    point->x = SaturateToLong((pointf->x - bx) / ax);
    point->y = SaturateToLong((pointf->y - by) / ay);
}

/*
 * Map a pixel inside rect back to model space: the same mapping MPtoLP
 * applies in the other direction (origin at the centre of rect, BASE_UNIT
 * pixels per model unit, model y growing upwards).
 */
void CNeuronView::LPtoMP(CRect& rect, CPoint *point, CPointf *pointf)
{
    /* pixel position of the model origin (integer midpoint) */
    const int iOrgX = rect.left + rect.Width() / 2;
    const int iOrgY = rect.top + rect.Height() / 2;

    /* model units per pixel; negated for y because screen y grows downwards */
    const double dbPerPixel = 1 / BASE_UNIT;

    pointf->x = point->x * dbPerPixel + (-iOrgX / BASE_UNIT);
    pointf->y = point->y * -dbPerPixel + iOrgY / BASE_UNIT;
}

/*
 * Draw x/y axes through the centre of rect with a 5-pixel tick every
 * BASE_UNIT pixels, mirrored on both sides of the origin. Leaves the black
 * axis pen selected in dc.
 */
void CNeuronView::DrawCoordinate(CDC& dc, CRect& rect)
{
    static CPen penCorrdinate(PS_SOLID, 1, RGB(0, 0, 0));

    const int iStep = (int)BASE_UNIT;   /* tick spacing in pixels */
    const int iTick = 5;                /* tick length in pixels */
    const int xMid = rect.left + rect.Width() / 2;
    const int yMid = rect.top + rect.Height() / 2;

    dc.SelectObject(&penCorrdinate);

    /* horizontal axis; ticks rise above it */
    dc.MoveTo(rect.left, yMid);
    dc.LineTo(rect.right, yMid);
    for (int d = iStep; xMid + d < rect.right; d += iStep)
    {
        dc.MoveTo(xMid + d, yMid);
        dc.LineTo(xMid + d, yMid - iTick);

        dc.MoveTo(xMid - d, yMid);
        dc.LineTo(xMid - d, yMid - iTick);
    }

    /* vertical axis; ticks extend to its right */
    dc.MoveTo(xMid, rect.top);
    dc.LineTo(xMid, rect.bottom);
    for (int d = iStep; yMid + d < rect.bottom; d += iStep)
    {
        dc.MoveTo(xMid, yMid + d);
        dc.LineTo(xMid + iTick, yMid + d);

        dc.MoveTo(xMid, yMid - d);
        dc.LineTo(xMid + iTick, yMid - d);
    }
}
/* True when v is neither infinite nor NaN (NaN fails every comparison). */
static bool IsFiniteDouble(double v)
{
    return fabs(v) <= DBL_MAX;
}

/*
 * Draw the line w1*x + w2*y + b = 0 (model coordinates) inside rect with the
 * pen currently selected in dc. Uses the same mapping as MPtoLP: origin at
 * the centre of rect, BASE_UNIT pixels per model unit, model y upwards.
 *
 * The line is evaluated at the left and right edges of rect and clipped to
 * [rect.top, rect.bottom] in floating point before converting to pixels, so
 * steep lines or diverging weights cannot overflow integer arithmetic.
 * Nothing is drawn outside rect.
 *
 * If w1 and w2 are both (nearly) zero there is no line: training is stopped
 * with EndAdjust() and, if a run was active, the user is told once.
 */
void CNeuronView::DrawLine(CDC &dc, CRect& rect, double w1, double w2, double b)
{
    const double dbEps = 0.0001;
    const int cx = rect.left + (rect.right - rect.left) / 2;
    const int cy = rect.top + (rect.bottom - rect.top) / 2;

    if (fabs(w2) < dbEps)
    {
        if (fabs(w1) < dbEps)
        {
            /* Report only when a run was just stopped: MessageBox runs a modal
               loop that may repaint this view, and that nested DrawLine call
               must not open another box (nor divide by w1 ~ 0 below). */
            const BOOL bWasAdjusting = (m_iCurrentPoint != -1);
            EndAdjust();
            if (bWasAdjusting)
                MessageBox("Neuron is Dead.", "Error");
            return;
        }

        /* treated as vertical: x = -b / w1 */
        const double dbX = cx + (-b / w1) * BASE_UNIT;
        if (!IsFiniteDouble(dbX) || dbX < rect.left || dbX > rect.right)
            return;     /* outside this box; do not draw into the neighbouring one */

        dc.MoveTo((int)dbX, rect.top);
        dc.LineTo((int)dbX, rect.bottom);
        return;
    }

    /* pixel rows of the line at the left and right edges: y = -(w1*x + b) / w2 */
    const double dbLeft = rect.left;
    const double dbRight = rect.right;
    const double dbY1 = cy + (w1 * ((dbLeft - cx) / BASE_UNIT) + b) / w2 * BASE_UNIT;
    const double dbY2 = cy + (w1 * ((dbRight - cx) / BASE_UNIT) + b) / w2 * BASE_UNIT;
    if (!IsFiniteDouble(dbY1) || !IsFiniteDouble(dbY2))
        return;

    const double dbTop = rect.top;
    const double dbBottom = rect.bottom;

    /* entirely above or entirely below the box */
    if ((dbY1 <= dbTop && dbY2 <= dbTop) || (dbY1 >= dbBottom && dbY2 >= dbBottom))
        return;

    double dbX1 = dbLeft, dbClipY1 = dbY1;
    double dbX2 = dbRight, dbClipY2 = dbY2;
    if (dbY1 != dbY2)
    {
        /* clip each end to the box's vertical range, sliding along the line */
        const double dbDxDy = (dbRight - dbLeft) / (dbY2 - dbY1);
        dbClipY1 = dbY1 < dbTop ? dbTop : (dbY1 > dbBottom ? dbBottom : dbY1);
        dbClipY2 = dbY2 < dbTop ? dbTop : (dbY2 > dbBottom ? dbBottom : dbY2);
        dbX1 = dbLeft + (dbClipY1 - dbY1) * dbDxDy;
        dbX2 = dbLeft + (dbClipY2 - dbY1) * dbDxDy;

        /* guard against rounding pushing x a hair outside the box */
        dbX1 = dbX1 < dbLeft ? dbLeft : (dbX1 > dbRight ? dbRight : dbX1);
        dbX2 = dbX2 < dbLeft ? dbLeft : (dbX2 > dbRight ? dbRight : dbX2);
    }

    dc.MoveTo((int)dbX1, (int)dbClipY1);
    dc.LineTo((int)dbX2, (int)dbClipY2);
}

/* WM_MOUSELEAVE (requested via _TrackMouseEvent in OnMouseMove):
   the cursor is no longer over the view, so OnSetCursor stops forcing the cross. */
LRESULT CNeuronView::OnMouseLeave(WPARAM wParam, LPARAM lParam)
{
    (void)wParam;
    (void)lParam;

    const LRESULT lHandled = 0;
    m_bOver = FALSE;
    return lHandled;
}

/* WM_MOUSEHOVER (requested via _TrackMouseEvent in OnMouseMove):
   the cursor rests over the view, so OnSetCursor shows the cross cursor. */
LRESULT CNeuronView::OnMouseHover(WPARAM wParam, LPARAM lParam)
{
    (void)wParam;
    (void)lParam;

    const LRESULT lHandled = 0;
    m_bOver = TRUE;
    return lHandled;
}

/* WM_MOUSEMOVE: (re)arm hover and leave tracking on every move so that
   OnMouseHover/OnMouseLeave keep m_bOver current, then run the default handling. */
void CNeuronView::OnMouseMove(UINT nFlags, CPoint point) 
{
    /* fields: cbSize, dwFlags, hwndTrack, dwHoverTime (1 ms) */
    TRACKMOUSEEVENT tracking = { sizeof(TRACKMOUSEEVENT), TME_LEAVE | TME_HOVER, m_hWnd, 1 };
    _TrackMouseEvent(&tracking);

    CStatic::OnMouseMove(nFlags, point);
}

void CNeuronView::PreSubclassWindow() 
{	
	CStatic::PreSubclassWindow();
    ModifyStyle(0, BS_OWNERDRAW | SS_NOTIFY);
}

/* WM_SETCURSOR: show the cross cursor (loaded in the constructor) while the
   mouse hovers over the view; otherwise defer to the default handling. */
BOOL CNeuronView::OnSetCursor(CWnd* pWnd, UINT nHitTest, UINT message) 
{
    if (!m_bOver)
        return CStatic::OnSetCursor(pWnd, nHitTest, message);

    ::SetCursor(m_hCursor);
    return TRUE;
}

void CNeuronView::AddPoint(const CPointf *pPointf, int iPointClass)
{
    m_vecPoint.push_back(*pPointf);
    m_vecPointClass.push_back(iPointClass);
}

/* Register (or clear, with NULL) the object that WeightsNotify reports to.
   The pointer is stored, not owned. */
void CNeuronView::SetWeightsListener(CWeightsListener *listener)
{
    this->m_pWeightsListener = listener;
}

/* Register (or clear, with NULL) the mouse listener. The pointer is stored,
   not owned. */
void CNeuronView::SetMouseListener(CMouseListener *listener)
{
    this->m_pMouseListener = listener;
}

void CNeuronView::Redraw()
{
    CRect rect;
    GetClientRect(&rect);
    RedrawWindow(&rect, NULL);
}

void CNeuronView::ResetNeuron(double dbLearnRate, double dbActiveLvl, double dbMomentum)
{
    if (m_pNeuralNetWork)
        delete m_pNeuralNetWork;

    m_pNeuralNetWork = new NeuralNetwork(2);
    m_pNeuralNetWork->appendLayer(2, dbLearnRate, dbActiveLvl, dbMomentum);
    m_pNeuralNetWork->appendLayer(1, dbLearnRate, dbActiveLvl, dbMomentum);

    m_iCurrentPoint = -1;
    WeightsNotify();
}

void CNeuronView::BeginAdjust()
{
    if (m_iCurrentPoint != -1 || m_vecPoint.empty())
        return;

    UINT retVal = SetTimer(m_uiTimerID, m_uiTimeout, NULL);
    if (retVal != m_uiTimerID)
    {
        MessageBox("Timer is not available");
        return;
    }

    m_iCurrentPoint = 0;
}

/* Stop training, if running: mark no current sample and kill the timer. */
void CNeuronView::EndAdjust()
{
    if (m_iCurrentPoint == -1)
        return;     /* not training */

    m_iCurrentPoint = -1;
    KillTimer(m_uiTimerID);
}

void CNeuronView::WeightsNotify()
{
    if (m_pWeightsListener)
    {
        double weights[9];
        weights[0] = m_pNeuralNetWork->getB(1, 0);
        weights[1] = m_pNeuralNetWork->getWeight(1, 0, 0);
        weights[2] = m_pNeuralNetWork->getWeight(1, 1, 0);
        weights[3] = m_pNeuralNetWork->getB(1, 1);
        weights[4] = m_pNeuralNetWork->getWeight(1, 0, 1);
        weights[5] = m_pNeuralNetWork->getWeight(1, 1, 1);
        weights[6] = m_pNeuralNetWork->getB(2, 0);
        weights[7] = m_pNeuralNetWork->getWeight(2, 0, 0);
        weights[8] = m_pNeuralNetWork->getWeight(2, 1, 0);
        m_pWeightsListener->OnWeightsNotify(weights, 9);
    }

}
/*
 * WM_TIMER: one training step per tick of the training timer.
 * Feeds sample m_iCurrentPoint to NeuralNetwork::learn ((x, y) as the two
 * inputs, its class as the single target), advances round-robin to the next
 * sample, repaints and notifies the weights listener. Other timer IDs, and
 * ticks while training is stopped, only get the default handling.
 */
void CNeuronView::OnTimer(UINT nIDEvent) 
{
    if (nIDEvent == m_uiTimerID && m_iCurrentPoint != -1 && m_pNeuralNetWork)
    {
        /* only indices valid for both parallel vectors may be used */
        int nPoints = (int)m_vecPoint.size();
        if ((int)m_vecPointClass.size() < nPoints)
            nPoints = (int)m_vecPointClass.size();

        if (nPoints == 0)
        {
            EndAdjust();    /* nothing to learn from: stop instead of indexing an empty set */
        }
        else
        {
            if (m_iCurrentPoint < 0 || m_iCurrentPoint >= nPoints)
                m_iCurrentPoint = 0;

            DVec vecInputs;
            DVec vecTarget;

            vecInputs.resize(2);
            vecInputs[0] = m_vecPoint[m_iCurrentPoint].x;
            vecInputs[1] = m_vecPoint[m_iCurrentPoint].y;

            vecTarget.resize(1);
            vecTarget[0] = m_vecPointClass[m_iCurrentPoint];

            m_pNeuralNetWork->learn(vecInputs, vecTarget);

            m_iCurrentPoint = (m_iCurrentPoint + 1) % nPoints;

            Redraw();
            WeightsNotify();
        }
    }
    CStatic::OnTimer(nIDEvent);
}

/* WM_LBUTTONDOWN: currently only the default handling.
   NOTE(review): m_pMouseListener is stored by SetMouseListener but never
   notified here - clicks are not reported to it yet. */
void CNeuronView::OnLButtonDown(UINT nFlags, CPoint point) 
{
    CStatic::OnLButtonDown(nFlags, point);
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -