📄 datatest.cpp
字号:
#include "factory.h"
#include "integer.h"
#include "filters.h"
#include "hex.h"
#include "randpool.h"
#include "files.h"
#include "trunhash.h"
#include "queue.h"
#include "validate.h"
#include <iostream>
#include <memory>
USING_NAMESPACE(CryptoPP)
USING_NAMESPACE(std)
// One test vector: maps each field name (e.g. "Name", "Test", "KeyFormat") to its
// undecoded text; byte-valued fields are decoded on demand by PutDecodedDatumInto().
typedef std::map<std::string, std::string> TestData;
class TestFailure : public Exception
{
public:
TestFailure() : Exception(OTHER_ERROR, "Validation test failed") {}
};
// Test vector currently being processed; SignalTestFailure() and SignalTestError()
// dump it to cerr. Starts out NULL -- presumably assigned by the test driver
// elsewhere in this file before any test runs (verify).
static const TestData *s_currentTestData = NULL;
static void OutputTestData(const TestData &v)
{
for (TestData::const_iterator i = v.begin(); i != v.end(); ++i)
{
cerr << i->first << ": " << i->second << endl;
}
}
//! Reports a failed test: dumps the current test vector (if any) to cerr and throws TestFailure.
static void SignalTestFailure()
{
	// s_currentTestData starts out NULL; never dereference it outside an active test vector.
	if (s_currentTestData)
		OutputTestData(*s_currentTestData);
	throw TestFailure();
}
//! Reports a malformed or unsupported test: dumps the current test vector (if any)
//! to cerr and throws a generic Exception.
static void SignalTestError()
{
	// s_currentTestData starts out NULL; never dereference it outside an active test vector.
	if (s_currentTestData)
		OutputTestData(*s_currentTestData);
	throw Exception(Exception::OTHER_ERROR, "Unexpected error during validation test");
}
const std::string & GetRequiredDatum(const TestData &data, const char *name)
{
TestData::const_iterator i = data.find(name);
if (i == data.end())
SignalTestError();
return i->second;
}
void PutDecodedDatumInto(const TestData &data, const char *name, BufferedTransformation &target)
{
std::string s1 = GetRequiredDatum(data, name), s2;
while (!s1.empty())
{
while (s1[0] == ' ')
s1 = s1.substr(1);
int repeat = 1;
if (s1[0] == 'r')
{
repeat = atoi(s1.c_str()+1);
s1 = s1.substr(s1.find(' ')+1);
}
s2 = ""; // MSVC 6 doesn't have clear();
if (s1[0] == '\"')
{
s2 = s1.substr(1, s1.find('\"', 1)-1);
s1 = s1.substr(s2.length() + 2);
}
else if (s1.substr(0, 2) == "0x")
{
StringSource(s1.substr(2, s1.find(' ')), true, new HexDecoder(new StringSink(s2)));
s1 = s1.substr(STDMIN(s1.find(' '), s1.length()));
}
else
{
StringSource(s1.substr(0, s1.find(' ')), true, new HexDecoder(new StringSink(s2)));
s1 = s1.substr(STDMIN(s1.find(' '), s1.length()));
}
ByteQueue q;
while (repeat--)
{
q.Put((const byte *)s2.data(), s2.size());
if (q.MaxRetrievable() > 4*1024 || repeat == 0)
q.TransferTo(target);
}
}
}
std::string GetDecodedDatum(const TestData &data, const char *name)
{
std::string s;
PutDecodedDatumInto(data, name, StringSink(s).Ref());
return s;
}
class TestDataNameValuePairs : public NameValuePairs
{
public:
TestDataNameValuePairs(const TestData &data) : m_data(data) {}
virtual bool GetVoidValue(const char *name, const std::type_info &valueType, void *pValue) const
{
TestData::const_iterator i = m_data.find(name);
if (i == m_data.end())
return false;
const std::string &value = i->second;
if (valueType == typeid(int))
*reinterpret_cast<int *>(pValue) = atoi(value.c_str());
else if (valueType == typeid(Integer))
*reinterpret_cast<Integer *>(pValue) = Integer((std::string(value) + "h").c_str());
else if (valueType == typeid(ConstByteArrayParameter))
{
m_temp.resize(0);
PutDecodedDatumInto(m_data, name, StringSink(m_temp).Ref());
reinterpret_cast<ConstByteArrayParameter *>(pValue)->Assign((const byte *)m_temp.data(), m_temp.size(), true);
}
else if (valueType == typeid(const byte *))
{
m_temp.resize(0);
PutDecodedDatumInto(m_data, name, StringSink(m_temp).Ref());
*reinterpret_cast<const byte * *>(pValue) = (const byte *)m_temp.data();
}
else
throw ValueTypeMismatch(name, typeid(std::string), valueType);
return true;
}
private:
const TestData &m_data;
mutable std::string m_temp;
};
//! Fails the current test unless both \p pub and \p priv pass Validate(GlobalRNG(), 3).
//! The public key is checked first; if it fails, the private key is not examined.
void TestKeyPairValidAndConsistent(CryptoMaterial &pub, const CryptoMaterial &priv)
{
	const bool bothValid = pub.Validate(GlobalRNG(), 3) && priv.Validate(GlobalRNG(), 3);
	if (!bothValid)
		SignalTestFailure();

	// NOTE(review): despite the name, consistency between the two keys is not
	// checked. The intended comparison is kept below, disabled:
	/*	EqualityComparisonFilter comparison;
		pub.Save(ChannelSwitch(comparison, "0"));
		pub.AssignFrom(priv);
		pub.Save(ChannelSwitch(comparison, "1"));
		comparison.ChannelMessageSeriesEnd("0");
		comparison.ChannelMessageSeriesEnd("1");
	*/
}
//! Runs one digital-signature test vector.
/*! The verifier's public key is loaded according to "KeyFormat" ("DER": the
	PublicKey field; "Component": individual fields; any other value loads
	nothing). Tests needing only the public key:
	  - "Verify" / "NotVerify": Signature followed by Message must (not) verify;
	  - "PublicKeyValid": the public key must pass Validate(GlobalRNG(), 3).
	All other tests first load the signer's private key the same way:
	  - "KeyPairValidAndConsistent": see TestKeyPairValidAndConsistent();
	  - "Sign": writes the hex-encoded signature of Message to cout and then
	    always reports a failure (NOTE(review): looks like a debugging stub);
	  - anything else is reported via SignalTestError(). */
void TestSignatureScheme(TestData &v)
{
	std::string name = GetRequiredDatum(v, "Name");
	std::string test = GetRequiredDatum(v, "Test");
	const char *algorithm = name.c_str();

	std::auto_ptr<PK_Signer> signer(ObjectFactoryRegistry<PK_Signer>::Registry().CreateObject(algorithm));
	std::auto_ptr<PK_Verifier> verifier(ObjectFactoryRegistry<PK_Verifier>::Registry().CreateObject(algorithm));

	TestDataNameValuePairs pairs(v);
	std::string keyFormat = GetRequiredDatum(v, "KeyFormat");
	const bool derKeys = (keyFormat == "DER");
	const bool componentKeys = (keyFormat == "Component");

	if (derKeys)
		verifier->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PublicKey")).Ref());
	else if (componentKeys)
		verifier->AccessMaterial().AssignFrom(pairs);

	// --- tests that need only the public key ---
	if (test == "Verify" || test == "NotVerify")
	{
		// Signature bytes are fed first, then the message (SIGNATURE_AT_BEGIN).
		VerifierFilter verifierFilter(*verifier, NULL, VerifierFilter::SIGNATURE_AT_BEGIN);
		PutDecodedDatumInto(v, "Signature", verifierFilter);
		PutDecodedDatumInto(v, "Message", verifierFilter);
		verifierFilter.MessageEnd();

		const bool mustReject = (test == "NotVerify");
		if (verifierFilter.GetLastResult() == mustReject)
			SignalTestFailure();
		return;
	}
	if (test == "PublicKeyValid")
	{
		if (!verifier->GetMaterial().Validate(GlobalRNG(), 3))
			SignalTestFailure();
		return;
	}

	// --- every remaining test also needs the private key ---
	if (derKeys)
		signer->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PrivateKey")).Ref());
	else if (componentKeys)
		signer->AccessMaterial().AssignFrom(pairs);

	if (test == "KeyPairValidAndConsistent")
	{
		TestKeyPairValidAndConsistent(verifier->AccessMaterial(), signer->GetMaterial());
	}
	else if (test == "Sign")
	{
		SignerFilter f(GlobalRNG(), *signer, new HexEncoder(new FileSink(cout)));
		StringSource ss(GetDecodedDatum(v, "Message"), true, new Redirector(f));
		SignalTestFailure();
	}
	else
	{
		// TODO: implement "DeterministicSign", "RandomSign" and "GenerateKey";
		// until then they, like unknown test names, are reported as test errors.
		SignalTestError();
	}
}
//! Runs one public-key encryption test vector.
/*! Both keys are loaded according to "KeyFormat" ("DER": the PrivateKey and
	PublicKey fields; "Component": individual fields via TestDataNameValuePairs;
	any other value loads nothing). Then the vector's "Test" is performed:
	  - "KeyPairValidAndConsistent": see TestKeyPairValidAndConsistent();
	  - "DecryptMatch": decrypting Ciphertext must reproduce Plaintext;
	  - anything else is reported via SignalTestError(). */
void TestAsymmetricCipher(TestData &v)
{
	std::string name = GetRequiredDatum(v, "Name");
	std::string test = GetRequiredDatum(v, "Test");
	const char *algorithm = name.c_str();

	std::auto_ptr<PK_Encryptor> encryptor(ObjectFactoryRegistry<PK_Encryptor>::Registry().CreateObject(algorithm));
	std::auto_ptr<PK_Decryptor> decryptor(ObjectFactoryRegistry<PK_Decryptor>::Registry().CreateObject(algorithm));

	std::string keyFormat = GetRequiredDatum(v, "KeyFormat");
	if (keyFormat == "DER")
	{
		decryptor->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PrivateKey")).Ref());
		encryptor->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PublicKey")).Ref());
	}
	else if (keyFormat == "Component")
	{
		TestDataNameValuePairs pairs(v);
		decryptor->AccessMaterial().AssignFrom(pairs);
		encryptor->AccessMaterial().AssignFrom(pairs);
	}

	if (test == "KeyPairValidAndConsistent")
	{
		TestKeyPairValidAndConsistent(encryptor->AccessMaterial(), decryptor->GetMaterial());
	}
	else if (test == "DecryptMatch")
	{
		std::string expected = GetDecodedDatum(v, "Plaintext");
		std::string decrypted;
		StringSource ss(GetDecodedDatum(v, "Ciphertext"), true, new PK_DecryptorFilter(GlobalRNG(), *decryptor, new StringSink(decrypted)));
		if (decrypted != expected)
			SignalTestFailure();
	}
	else
	{
		SignalTestError();	// unknown test type
	}
}
void TestSymmetricCipher(TestData &v)
{
std::string name = GetRequiredDatum(v, "Name");
std::string test = GetRequiredDatum(v, "Test");
std::string key = GetDecodedDatum(v, "Key");
std::string plaintext = GetDecodedDatum(v, "Plaintext");
TestDataNameValuePairs pairs(v);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -