⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 lockfree.cpp

📁 游戏编程精粹6第1章 通用编程,对入门的游戏开发者很有帮助.
💻 CPP
📖 第 1 页 / 共 2 页
字号:
    {
        std::cout << "CAS is INCORRECT." << std::endl;
    }
    else
    {
        pNode = &oldVal;
        if(!CAS_intrinsic(&pNode, &oldVal, &newVal))
        {
            std::cout << "CAS is INCORRECT." << std::endl;
        }
        else if(pNode != &newVal)
        {
            std::cout << "CAS is INCORRECT." << std::endl;
        }
        else
        {
            std::cout << "CAS is correct." << std::endl;
        }
    }
}
#endif

//
// Verify Windows API version of CAS.
//
void Test_CAS_windows()
{
    std::cout << "Testing CAS_windows...";

    node<MyStruct> oldVal;
    node<MyStruct> newVal;
    node<MyStruct> * pNode = &newVal;
    if(CAS_windows(&pNode, &oldVal, &newVal))
    {
        std::cout << "CAS is INCORRECT." << std::endl;
    }
    else
    {
        pNode = &oldVal;
        if(!CAS_windows(&pNode, &oldVal, &newVal))
        {
            std::cout << "CAS is INCORRECT." << std::endl;
        }
        else if(pNode != &newVal)
        {
            std::cout << "CAS is INCORRECT." << std::endl;
        }
        else
        {
            std::cout << "CAS is correct." << std::endl;
        }
    }
}

//
// Pointer/tag pair used as the target of the CAS2 tests.
// The tests hand the address of pNode to CAS2 and afterwards expect tag to
// have been updated as well, so the member order (pNode first, tag directly
// after it) must not be changed.
//
template<typename T>
struct CAS2Test
{
    node<T> * pNode;    // pointer half of the pair
    uint32_t  tag;      // tag half of the pair

    CAS2Test(node<T> * pInitialNode, uint32_t initialTag)
        : pNode(pInitialNode)
        , tag(initialTag)
    {
    }
};

//
// Verify Assembly version of CAS2.
//
void Test_CAS2_assembly()
{
    std::cout << "Testing CAS2_assembly...";

    node<MyStruct> oldVal;
    node<MyStruct> newVal;

    CAS2Test<MyStruct> myStruct(&newVal, 0xABCD);

    if(CAS2_assembly(&myStruct.pNode, &oldVal, 0xABCD, &newVal, 0xAAAA))
    {
        // should not succeed if pointers don't match
        std::cout << "CAS2 is INCORRECT." << std::endl;
    }
    else if(CAS2_assembly(&myStruct.pNode, &newVal, 0xAAAA, &oldVal, 0xABCD))
    {
        // should not succeed if tags don't match
        std::cout << "CAS2 is INCORRECT." << std::endl;
    }
    else
    {
        myStruct.pNode = &oldVal;
        if(!CAS2_assembly(&myStruct.pNode, &oldVal, 0xABCD, &newVal, 0xAAAA))
        {
            std::cout << "CAS2 is INCORRECT." << std::endl;
        }
        else if(myStruct.pNode != &newVal)
        {
            std::cout << "CAS2 is INCORRECT." << std::endl;
        }
        else if(myStruct.tag != 0xAAAA)
        {
            std::cout << "CAS2 is INCORRECT." << std::endl;
        }
        else
        {
            std::cout << "CAS2 is correct." << std::endl;
        }
    }
}

//
// Verify compiler intrinsic version of CAS2.
//
#ifdef _MSC_VER
void Test_CAS2_intrinsic()
{
    std::cout << "Testing CAS2_intrinsic...";

    node<MyStruct> oldVal;
    node<MyStruct> newVal;

    CAS2Test<MyStruct> myStruct(&newVal, 0xABCD);

    if(CAS2_intrinsic(&myStruct.pNode, &oldVal, 0xABCD, &newVal, 0xAAAA))
    {
        // should not succeed if pointers don't match
        std::cout << "CAS2 is INCORRECT." << std::endl;
    }
    else if(CAS2_intrinsic(&myStruct.pNode, &newVal, 0xAAAA, &oldVal, 0xABCD))
    {
        // should not succeed if tags don't match
        std::cout << "CAS2 is INCORRECT." << std::endl;
    }
    else
    {
        myStruct.pNode = &oldVal;
        if(!CAS2_intrinsic(&myStruct.pNode, &oldVal, 0xABCD, &newVal, 0xAAAA))
        {
            std::cout << "CAS2 is INCORRECT." << std::endl;
        }
        else if(myStruct.pNode != &newVal)
        {
            std::cout << "CAS2 is INCORRECT." << std::endl;
        }
        else if(myStruct.tag != 0xAAAA)
        {
            std::cout << "CAS2 is INCORRECT." << std::endl;
        }
        else
        {
            std::cout << "CAS2 is correct." << std::endl;
        }
    }
}
#endif

//
// Verify Windows API version of CAS2.
//
#if WINVER >= 0x0600
void Test_CAS2_windows()
{
    std::cout << "Testing CAS2_windows...";

    node<MyStruct> oldVal;
    node<MyStruct> newVal;

    CAS2Test<MyStruct> myStruct(&newVal, 0xABCD);

    if(CAS2_windows(&myStruct.pNode, &oldVal, 0xABCD, &newVal, 0xAAAA))
    {
        // should not succeed if pointers don't match
        std::cout << "CAS2 is INCORRECT." << std::endl;
    }
    else if(CAS2_windows(&myStruct.pNode, &newVal, 0xAAAA, &oldVal, 0xABCD))
    {
        // should not succeed if tags don't match
        std::cout << "CAS2 is INCORRECT." << std::endl;
    }
    else
    {
        myStruct.pNode = &oldVal;
        if(!CAS2_windows(&myStruct.pNode, &oldVal, 0xABCD, &newVal, 0xAAAA))
        {
            std::cout << "CAS2 is INCORRECT." << std::endl;
        }
        else if(myStruct.pNode != &newVal)
        {
            std::cout << "CAS2 is INCORRECT." << std::endl;
        }
        else if(myStruct.tag != 0xAAAA)
        {
            std::cout << "CAS2 is INCORRECT." << std::endl;
        }
        else
        {
            std::cout << "CAS2 is correct." << std::endl;
        }
    }
}
#endif  // WINVER >= 0x0600

//
// Block until the given thread has exited, then release its handle.
//
// hThread - thread handle; set to NULL once it has been closed.
//
// A NULL handle is ignored: _beginthreadex returns 0 when it fails to create
// a thread, and such a value must not be handed to WaitForSingleObject or
// CloseHandle.
//
void HandleWait(HANDLE & hThread)
{
    if(hThread == NULL)
    {
        return;
    }

    WaitForSingleObject(hThread, INFINITE);
    CloseHandle(hThread);

    // Do not leave a stale handle value behind in the caller's container.
    hThread = NULL;
}

template<typename T>
void CreateNode(node<T> * & pNode)
{
    pNode = new node<T>;
}

template<typename T>
void DeleteNode(node<T> * & pNode)
{
    delete pNode;
}

//
// Stress the multithreaded stack code.
//
template<typename T, int NUMTHREADS>
class StressStack
{
    LockFreeStack<T> _stack;

    static const unsigned int cNodes = 100;    // nodes per thread

    struct ThreadData
    {
        StressStack<T, NUMTHREADS> * pStress;
        DWORD thread_num;
    };

    std::vector<ThreadData> _aThreadData;
    std::vector<node<T> *> _apNodes;

public:
    StressStack() : _aThreadData(NUMTHREADS), _apNodes(cNodes * NUMTHREADS) {}

    //
    // The stack stress will spawn a number of threads (4096 in our tests), each of which will
    // push and pop nodes onto a single stack.  We expect that no access violations will occur
    // and that the stack is empty upon completion.
    //
    void operator()()
    {
        std::cout << "Running Stack Stress..." << std::endl;

        //
        // Create all of the nodes.
        //
        std::for_each(_apNodes.begin(), _apNodes.end(), CreateNode<T>);

        unsigned int ii;
        for(ii = 0; ii < _aThreadData.size(); ++ii)
        {
            _aThreadData[ii].pStress = this;
            _aThreadData[ii].thread_num = ii;
        }

        std::vector<HANDLE> aHandles(NUMTHREADS);
        for(ii = 0; ii < aHandles.size(); ++ii)
        {
            unsigned int tid;
            aHandles[ii] = (HANDLE)_beginthreadex(NULL, 0, StackThreadFunc, &_aThreadData[ii], 0, &tid);
        }

        //
        // Wait for the threads to exit.
        //
        std::for_each(aHandles.begin(), aHandles.end(), HandleWait);

        //
        // Delete all of the nodes.
        //
        std::for_each(_apNodes.begin(), _apNodes.end(), DeleteNode<T>);

        //
        // Ideas for improvement:
        //  We could verify that there is a 1-1 mapping between values pushed and values popped.
        //  Verify the count of pops in the stack matches the number of pops for each thread.
        //
    } // void operator()()

    static unsigned int __stdcall StackThreadFunc(void * pv)
    {
        unsigned int tid = GetCurrentThreadId();
        ThreadData * ptd = reinterpret_cast<ThreadData *>(pv);
        if(FULL_TRACE)
        {
            std::cout << tid << " adding" << std::endl;
        }

        unsigned int ii;
        for(ii = 0; ii < cNodes; ++ii)
        {
            ptd->pStress->_stack.Push(ptd->pStress->_apNodes[ptd->thread_num * cNodes + ii]);
        }

        if(FULL_TRACE)
        {
            std::cout << tid << " removing" << std::endl;
        }

        for(ii = 0; ii < cNodes; ++ii)
        {
            ptd->pStress->_apNodes[ptd->thread_num * cNodes + ii] = ptd->pStress->_stack.Pop();
        }

        return 0;
    }
};  // class StressStack

//
// Stress the multithreaded queue code.
//
template<typename T, int NUMTHREADS>
class StressQueue
{
    LockFreeQueue<T> _queue;

    struct ThreadData
    {
        StressQueue<T, NUMTHREADS> * pStress;
        DWORD thread_num;
    };

    std::vector<ThreadData> _aThreadData;
    std::vector<node<T> *> & _apNodes;

public:
    static const unsigned int cNodes = 100;     // nodes per thread

    StressQueue(std::vector<node<T> *> & apNodes) : _queue(apNodes[0]), _aThreadData(NUMTHREADS), _apNodes(apNodes) {}

    //
    // The queue stress will spawn a number of threads (4096 in our tests), each of which will
    // add and remove nodes on a single queue.  We expect that no access violations will occur
    // and that the queue is empty (except for the dummy node) upon completion.
    //
    void operator()()
    {
        std::cout << "Running Queue Stress..." << std::endl;

        unsigned int ii;
        for(ii = 0; ii < _aThreadData.size(); ++ii)
        {
            _aThreadData[ii].pStress = this;
            _aThreadData[ii].thread_num = ii;
        }

        std::vector<HANDLE> aHandles(NUMTHREADS);
        for(ii = 0; ii < aHandles.size(); ++ii)
        {
            unsigned int tid;
            aHandles[ii] = (HANDLE)_beginthreadex(NULL, 0, QueueThreadFunc, &_aThreadData[ii], 0, &tid);
        }

        //
        // Wait for the threads to exit.
        //
        std::for_each(aHandles.begin(), aHandles.end(), HandleWait);

        //
        // Ideas for improvement:
        //  We could verify that there is a 1-1 mapping between values added and values removed.
        //  Verify the count of pops in the queue matches the number of pops for each thread.
        //
    } // void operator()()

    static unsigned int __stdcall QueueThreadFunc(void * pv)
    {
        unsigned int tid = GetCurrentThreadId();
        ThreadData * ptd = reinterpret_cast<ThreadData *>(pv);
        if(FULL_TRACE)
        {
            std::cout << tid << " adding" << std::endl;
        }

        unsigned int ii;
        for(ii = 0; ii < cNodes; ++ii)
        {
            ptd->pStress->_queue.Add(ptd->pStress->_apNodes[ptd->thread_num * cNodes + ii + 1]);
        }

        if(FULL_TRACE)
        {
            std::cout << tid << " removing" << std::endl;
        }

        for(ii = 0; ii < cNodes; ++ii)
        {
            ptd->pStress->_queue.Remove();
        }

        return 0;
    }
};  // class StressQueue

//
// Demonstrate the lock-free freelist.
// The freelist is based off of ideas found in the freelist article in Game
// Programming Gems 4 by Paul Glinker.  Other ideas for improvement can be
// found in the freelist article in Game Programming Gems 5 by Nathan Mefford.
//
void Demo_Freelist()
{
    std::cout << "Demo of Freelist...";

    //
    // Create a Freelist of MyStructs with 10 elements.
    //
    LockFreeFreeList<MyStruct> fl(10);

    //
    // Allocate a new MyStruct object.
    //
    MyStruct * pStruct = fl.NewInstance();

    //
    // Destroy the MyStruct object and return it to the Freelist.
    //
    fl.FreeInstance(pStruct);

    std::cout << "done" << std::endl;
}

//
// Nothing of importance happens here. :)
//
int main(int, char **)
{
    //
    // Test CAS and CAS2
    //
    Test_CAS_assembly();
#ifdef _MSC_VER
    Test_CAS_intrinsic();
#endif
    Test_CAS_windows();

    Test_CAS2_assembly();
#ifdef _MSC_VER
    Test_CAS2_intrinsic();
#endif
#if WINVER >= 0x0600
    Test_CAS2_windows();
#endif

    //
    // Test Lock-free Stack
    //
    StressStack<TEST_TYPE, 4096>()();

    // Demo the stack
    node<MyStruct> Nodes[10];

    LockFreeStack<MyStruct> stack;
    stack.Push(&Nodes[1]);
    stack.Pop();        // returns &Nodes[1]
    stack.Pop();        // returns NULL

    //
    // Test Lock-free Queue
    //
    std::vector<node<TEST_TYPE> *> apNodes(StressQueue<TEST_TYPE, 4096>::cNodes * 4096 + 1);    // 4096 threads, and 1 extra dummy node
    std::for_each(apNodes.begin(), apNodes.end(), CreateNode<TEST_TYPE>);

    StressQueue<TEST_TYPE, 4096> theQueue(apNodes);
    theQueue();

    std::for_each(apNodes.begin(), apNodes.end(), DeleteNode<TEST_TYPE>);

    // Demo the queue
    LockFreeQueue<MyStruct> queue(&Nodes[0]);   // Nodes[0] is dummy node

    queue.Add(&Nodes[1]);
    queue.Remove();     // returns &Nodes[1]
    queue.Remove();     // returns NULL;

    //
    // Demonstrate Lock-free Freelist
    //
    Demo_Freelist();

    return 0;
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -