📄 datatest.cpp
字号:
#include "factory.h"
#include "integer.h"
#include "filters.h"
#include "hex.h"
#include "randpool.h"
#include "files.h"
#include "trunhash.h"
#include "queue.h"
#include "validate.h"
#include <cassert>
#include <cstdlib>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <typeinfo>

USING_NAMESPACE(CryptoPP)
USING_NAMESPACE(std)

// One test record: field name -> raw (still encoded) field text.
typedef std::map<std::string, std::string> TestData;

// Thrown when a test record does not produce the expected result.
class TestFailure : public Exception
{
public:
	TestFailure() : Exception(OTHER_ERROR, "Validation test failed") {}
};

// Record currently under test; dumped to cerr by SignalTestFailure/SignalTestError.
// NOTE(review): assigned outside this section, presumably by the test-file driver.
static const TestData *s_currentTestData = NULL;

// Writes every "name: value" field of v to cerr, one per line.
static void OutputTestData(const TestData &v)
{
	for (TestData::const_iterator i = v.begin(); i != v.end(); ++i)
		cerr << i->first << ": " << i->second << endl;
}

// Dumps the current record (if one is set) and throws TestFailure.
static void SignalTestFailure()
{
	// Dereferencing a NULL s_currentTestData would be undefined behavior.
	if (s_currentTestData)
		OutputTestData(*s_currentTestData);
	throw TestFailure();
}

// Dumps the current record (if one is set) and throws a generic Exception.
// Used for malformed or unsupported test data rather than a wrong result.
static void SignalTestError()
{
	if (s_currentTestData)
		OutputTestData(*s_currentTestData);
	throw Exception(Exception::OTHER_ERROR, "Unexpected error during validation test");
}

// Returns true if the record contains a field called name.
bool DataExists(const TestData &data, const char *name)
{
	return data.find(name) != data.end();
}

// Returns the raw text of field name; calls SignalTestError() (which throws) if it is absent.
const std::string & GetRequiredDatum(const TestData &data, const char *name)
{
	TestData::const_iterator i = data.find(name);
	if (i == data.end())
		SignalTestError();
	return i->second;
}

// Decodes field name and writes the resulting bytes to target.
// The field is a space-separated list of tokens, each optionally preceded by a
// repeat prefix "rN " meaning "emit the next token N times":
//   "text"   the literal bytes between the double quotes
//   0xHEX    hex-encoded bytes (the 0x prefix is skipped)
//   HEX      hex-encoded bytes
// Calls SignalTestError() if the field is missing or a quoted token is unterminated.
void PutDecodedDatumInto(const TestData &data, const char *name, BufferedTransformation &target)
{
	std::string s1 = GetRequiredDatum(data, name), s2;

	while (!s1.empty())
	{
		// Skip separating spaces; stop when only spaces remain.
		std::string::size_type start = s1.find_first_not_of(' ');
		if (start == std::string::npos)
			break;
		s1 = s1.substr(start);

		int repeat = 1;
		if (s1[0] == 'r')
		{
			repeat = atoi(s1.c_str()+1);
			s1 = s1.substr(s1.find(' ')+1);
		}

		s2 = "";	// MSVC 6 doesn't have clear();
		if (!s1.empty() && s1[0] == '\"')
		{
			// Quoted literal: the bytes between this quote and the next one.
			std::string::size_type closing = s1.find('\"', 1);
			if (closing == std::string::npos)
				SignalTestError();	// unterminated literal (previously std::out_of_range)
			s2 = s1.substr(1, closing-1);
			s1 = s1.substr(closing+1);
		}
		else
		{
			// Hex token, optionally "0x"-prefixed, running up to the next space or the end.
			// Pass the token's *length* to substr: passing the space's index as the length
			// (as before) also took the space and characters of the following token.
			std::string::size_type tokenBegin = (s1.substr(0, 2) == "0x") ? 2 : 0;
			std::string::size_type tokenEnd = STDMIN(s1.find(' '), s1.length());
			StringSource(s1.substr(tokenBegin, tokenEnd - tokenBegin), true, new HexDecoder(new StringSink(s2)));
			s1 = s1.substr(tokenEnd);
		}

		// Counting down with an explicit "> 0" test means a zero or negative "r" count
		// emits nothing instead of looping until signed overflow.
		ByteQueue q;
		for (int remaining = repeat; remaining > 0; --remaining)
		{
			q.Put((const byte *)s2.data(), s2.size());
			// Flush every ~4 KB and after the last copy so large repeat counts stay bounded in memory.
			if (q.MaxRetrievable() > 4*1024 || remaining == 1)
				q.TransferTo(target);
		}
	}
}

// Returns the decoded bytes of field name (see PutDecodedDatumInto); the field must exist.
std::string GetDecodedDatum(const TestData &data, const char *name)
{
	std::string s;
	PutDecodedDatumInto(data, name, StringSink(s).Ref());
	return s;
}

// Like GetDecodedDatum, but returns an empty string when the field is absent.
std::string GetOptionalDecodedDatum(const TestData &data, const char *name)
{
	std::string s;
	if (DataExists(data, name))
		PutDecodedDatumInto(data, name, StringSink(s).Ref());
	return s;
}

// Adapts a TestData record to the NameValuePairs interface, so that keys and ciphers
// can be configured directly from test-file fields:
//   int                      -> atoi of the raw text
//   Integer                  -> raw text with an "h" suffix appended (Crypto++ hex notation)
//   ConstByteArrayParameter  -> the field's decoded bytes
// A request for Name::DigestSize() as an int that has no explicit field is answered
// with the decoded length of the "MAC" field, or else of the "Digest" field.
class TestDataNameValuePairs : public NameValuePairs
{
public:
	// Keeps a reference to data; the record must outlive this object.
	TestDataNameValuePairs(const TestData &data) : m_data(data) {}

	virtual bool GetVoidValue(const char *name, const std::type_info &valueType, void *pValue) const
	{
		TestData::const_iterator i = m_data.find(name);
		if (i == m_data.end())
		{
			if (std::string(name) == Name::DigestSize() && valueType == typeid(int))
			{
				i = m_data.find("MAC");
				if (i == m_data.end())
					i = m_data.find("Digest");
				if (i == m_data.end())
					return false;

				m_temp.resize(0);
				PutDecodedDatumInto(m_data, i->first.c_str(), StringSink(m_temp).Ref());
				*reinterpret_cast<int *>(pValue) = (int)m_temp.size();
				return true;
			}
			return false;
		}

		const std::string &value = i->second;

		if (valueType == typeid(int))
			*reinterpret_cast<int *>(pValue) = atoi(value.c_str());
		else if (valueType == typeid(Integer))
			*reinterpret_cast<Integer *>(pValue) = Integer((value + "h").c_str());
		else if (valueType == typeid(ConstByteArrayParameter))
		{
			m_temp.resize(0);
			PutDecodedDatumInto(m_data, name, StringSink(m_temp).Ref());
			// The parameter points into m_temp rather than owning a copy (NOTE(review): assumes
			// Assign's third argument is deepCopy), so it is only valid until the next lookup.
			reinterpret_cast<ConstByteArrayParameter *>(pValue)->Assign((const byte *)m_temp.data(), m_temp.size(), false);
		}
		else
			throw ValueTypeMismatch(name, typeid(std::string), valueType);

		return true;
	}

private:
	const TestData &m_data;
	mutable std::string m_temp;	// scratch buffer for decoded values
};

// Fails the test unless both keys pass Validate() at level 3.
// A check that pub actually matches priv is sketched below but disabled.
void TestKeyPairValidAndConsistent(CryptoMaterial &pub, const CryptoMaterial &priv)
{
	if (!pub.Validate(GlobalRNG(), 3))
		SignalTestFailure();
	if (!priv.Validate(GlobalRNG(), 3))
		SignalTestFailure();

/*	EqualityComparisonFilter comparison;
	pub.Save(ChannelSwitch(comparison, "0"));
	pub.AssignFrom(priv);
	pub.Save(ChannelSwitch(comparison, "1"));

	comparison.ChannelMessageSeriesEnd("0");
	comparison.ChannelMessageSeriesEnd("1");
*/
}

// Runs one signature-scheme test record.
// Fields: Name, Test, KeyFormat ("DER" reads PublicKey/PrivateKey, "Component" reads
// individual fields through TestDataNameValuePairs), plus Signature/Message for verify tests.
// Tests: Verify, NotVerify, PublicKeyValid, KeyPairValidAndConsistent. "Sign" only prints
// the hex signature to cout and then reports failure; DeterministicSign, RandomSign,
// GenerateKey and unknown tests call SignalTestError().
void TestSignatureScheme(TestData &v)
{
	std::string name = GetRequiredDatum(v, "Name");
	std::string test = GetRequiredDatum(v, "Test");

	// member_ptr (also used by TestSymmetricCipher) replaces std::auto_ptr,
	// which was deprecated in C++11 and removed in C++17.
	member_ptr<PK_Signer> signer(ObjectFactoryRegistry<PK_Signer>::Registry().CreateObject(name.c_str()));
	member_ptr<PK_Verifier> verifier(ObjectFactoryRegistry<PK_Verifier>::Registry().CreateObject(name.c_str()));

	TestDataNameValuePairs pairs(v);
	std::string keyFormat = GetRequiredDatum(v, "KeyFormat");

	if (keyFormat == "DER")
		verifier->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PublicKey")).Ref());
	else if (keyFormat == "Component")
		verifier->AccessMaterial().AssignFrom(pairs);

	// Tests that need only the public key.
	if (test == "Verify" || test == "NotVerify")
	{
		VerifierFilter verifierFilter(*verifier, NULL, VerifierFilter::SIGNATURE_AT_BEGIN);
		PutDecodedDatumInto(v, "Signature", verifierFilter);
		PutDecodedDatumInto(v, "Message", verifierFilter);
		verifierFilter.MessageEnd();
		// Verify expects a valid signature, NotVerify an invalid one.
		if (verifierFilter.GetLastResult() == (test == "NotVerify"))
			SignalTestFailure();
		return;
	}

	if (test == "PublicKeyValid")
	{
		if (!verifier->GetMaterial().Validate(GlobalRNG(), 3))
			SignalTestFailure();
		return;
	}

	// The remaining tests also need the private key.
	if (keyFormat == "DER")
		signer->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PrivateKey")).Ref());
	else if (keyFormat == "Component")
		signer->AccessMaterial().AssignFrom(pairs);

	if (test == "KeyPairValidAndConsistent")
	{
		TestKeyPairValidAndConsistent(verifier->AccessMaterial(), signer->GetMaterial());
	}
	else if (test == "Sign")
	{
		// Prints the signature as hex to cout, then reports failure: the output is
		// not compared against an expected value here.
		SignerFilter f(GlobalRNG(), *signer, new HexEncoder(new FileSink(cout)));
		StringSource ss(GetDecodedDatum(v, "Message"), true, new Redirector(f));
		SignalTestFailure();
	}
	else if (test == "DeterministicSign")
	{
		SignalTestError();
		assert(false);	// TODO: implement
	}
	else if (test == "RandomSign")
	{
		SignalTestError();
		assert(false);	// TODO: implement
	}
	else if (test == "GenerateKey")
	{
		SignalTestError();
		assert(false);
	}
	else
	{
		SignalTestError();
		assert(false);
	}
}

// Runs one public-key encryption test record.
// Keys come from PrivateKey/PublicKey ("DER") or from individual fields ("Component").
// Tests: DecryptMatch (decrypting Ciphertext must yield Plaintext) and
// KeyPairValidAndConsistent; any other test calls SignalTestError().
void TestAsymmetricCipher(TestData &v)
{
	std::string name = GetRequiredDatum(v, "Name");
	std::string test = GetRequiredDatum(v, "Test");

	member_ptr<PK_Encryptor> encryptor(ObjectFactoryRegistry<PK_Encryptor>::Registry().CreateObject(name.c_str()));
	member_ptr<PK_Decryptor> decryptor(ObjectFactoryRegistry<PK_Decryptor>::Registry().CreateObject(name.c_str()));

	std::string keyFormat = GetRequiredDatum(v, "KeyFormat");

	if (keyFormat == "DER")
	{
		decryptor->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PrivateKey")).Ref());
		encryptor->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PublicKey")).Ref());
	}
	else if (keyFormat == "Component")
	{
		TestDataNameValuePairs pairs(v);
		decryptor->AccessMaterial().AssignFrom(pairs);
		encryptor->AccessMaterial().AssignFrom(pairs);
	}

	if (test == "DecryptMatch")
	{
		std::string decrypted, expected = GetDecodedDatum(v, "Plaintext");
		StringSource ss(GetDecodedDatum(v, "Ciphertext"), true, new PK_DecryptorFilter(GlobalRNG(), *decryptor, new StringSink(decrypted)));
		if (decrypted != expected)
			SignalTestFailure();
	}
	else if (test == "KeyPairValidAndConsistent")
	{
		TestKeyPairValidAndConsistent(encryptor->AccessMaterial(), decryptor->GetMaterial());
	}
	else
	{
		SignalTestError();
		assert(false);
	}
}

// Runs one symmetric-cipher test record; overrideParameters is combined with the
// record's own fields when configuring the cipher.
void TestSymmetricCipher(TestData &v, const NameValuePairs &overrideParameters)
{
	std::string name = GetRequiredDatum(v, "Name");
	std::string test = GetRequiredDatum(v, "Test");

	std::string key = GetDecodedDatum(v, "Key");
	std::string plaintext = GetDecodedDatum(v, "Plaintext");

	TestDataNameValuePairs testDataPairs(v);
	CombinedNameValuePairs pairs(overrideParameters, testDataPairs);

	if (test == "Encrypt" || test == "EncryptXorDigest" || test == "Resync")
	{
		static member_ptr<SymmetricCipher> encryptor, decryptor;
		static std::string lastName;

		if (name != lastName)
		{
encryptor.reset(ObjectFactoryRegistry<SymmetricCipher, ENCRYPTION>::Registry().CreateObject(name.c_str())); decryptor.reset(ObjectFactoryRegistry<SymmetricCipher, DECRYPTION>::Registry().CreateObject(name.c_str())); lastName = name; } ConstByteArrayParameter iv; if (pairs.GetValue(Name::IV(), iv) && iv.size() != encryptor->IVSize()) SignalTestFailure(); if (test == "Resync") { encryptor->Resynchronize(iv.begin(), (int)iv.size()); decryptor->Resynchronize(iv.begin(), (int)iv.size()); } else { encryptor->SetKey((const byte *)key.data(), key.size(), pairs); decryptor->SetKey((const byte *)key.data(), key.size(), pairs); } int seek = pairs.GetIntValueWithDefault("Seek", 0); if (seek) { encryptor->Seek(seek); decryptor->Seek(seek); } std::string encrypted, xorDigest, ciphertext, ciphertextXorDigest; StringSource ss(plaintext, false, new StreamTransformationFilter(*encryptor, new StringSink(encrypted), StreamTransformationFilter::NO_PADDING)); ss.Pump(plaintext.size()/2 + 1); ss.PumpAll(); /*{ std::string z; encryptor->Seek(seek); StringSource ss(plaintext, false, new StreamTransformationFilter(*encryptor, new StringSink(z), StreamTransformationFilter::NO_PADDING)); while (ss.Pump(64)) {}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -