Kiwi/tools/Evaluator.cpp

150 lines
3.9 KiB
C++

#include <fstream>
#include <iostream>
#include <kiwi/Utils.h>
#include "../src/StrUtils.h"
#include "Evaluator.h"
#include "LCS.hpp"
using namespace std;
using namespace kiwi;
// Parses one expected token written as "<form>/<TAG>" into a TokenInfo.
//
// The text before the LAST '/' is the form; it is passed through replace() with u"_" and u" "
// (presumably turning underscores into spaces - confirm against StrUtils.h).
// The text after the '/' is the tag name; everything from its first '-' onward is dropped
// before it is handed to toPOSTag().
//
// Returns a default-constructed TokenInfo when `str` contains no '/'.
// Throws std::runtime_error when the converted tag is >= POSTag::max.
TokenInfo parseWordPOS(const u16string& str)
{
auto p = str.rfind('/');
if (p == str.npos) return {};
u16string form = replace(nonstd::u16string_view(str.data(), p), u"_", u" ");
// For tags whose name starts with 'E', the first character of the form is rewritten.
// NOTE(review): the character literals on the next two statements appear empty in this copy
// of the file - presumably Hangul characters that were lost; restore them from the upstream
// source before building.
if (str[p + 1] == 'E')
{
if (form[0] == u'' || form[0] == u'') form[0] = u'';
if (form[0] == u'' || form[0] == u'') form[0] = u'';
}
// A leading Hangul compatibility jamo (U+3134/3139/3141/3142) is replaced by the matching
// conjoining final-consonant jamo (U+11AB/11AF/11B7/11B8).
switch (form[0])
{
case u'\u3134': // ㄴ
form[0] = u'\u11AB'; break;
case u'\u3139': // ㄹ
form[0] = u'\u11AF'; break;
case u'\u3141': // ㅁ
form[0] = u'\u11B7'; break;
case u'\u3142': // ㅂ
form[0] = u'\u11B8'; break;
}
// Keep only the part of the tag name that precedes the first '-'.
u16string tagStr = str.substr(p + 1);
if (tagStr.find('-') != tagStr.npos)
{
tagStr.erase(tagStr.begin() + tagStr.find('-'), tagStr.end());
}
POSTag tag = toPOSTag(tagStr);
// The error message reports the tag text as written in the input, including any '-' suffix.
if (tag >= POSTag::max) throw runtime_error{ "Wrong Input '" + utf16To8(str.substr(p + 1)) + "'" };
// The two trailing TokenInfo fields are initialized to 0.
return { form, tag, 0, 0 };
}
/// Loads a test set from a UTF-8 text file.
///
/// Every line is split on tabs: field 0 is the query text, and each following field is split
/// on spaces into tokens that are parsed with parseWordPOS(). Lines with fewer than two
/// tab-separated fields (including blank lines) are skipped.
///
/// @param testSetFile   path of the test set file
/// @param _kw           analyzer stored for later use by run()
/// @param _matchOption  match option stored for later use by run()
/// @param _topN         number of candidates stored for later use by run()
/// @throws std::ios_base::failure when the file cannot be opened
/// @throws std::runtime_error propagated from parseWordPOS() for an unknown tag
Evaluator::Evaluator(const std::string& testSetFile, const Kiwi* _kw, Match _matchOption, size_t _topN)
	: kw{ _kw }, matchOption{ _matchOption }, topN{ _topN }
{
	ifstream f{ testSetFile };
	if (!f) throw std::ios_base::failure{ "Cannot open '" + testSetFile + "'" };

	string line;
	while (getline(f, line))
	{
		// Strip trailing CR/LF. The emptiness check is required: calling back() on an empty
		// string (a blank line, or a line consisting only of '\r') is undefined behavior.
		while (!line.empty() && (line.back() == '\n' || line.back() == '\r')) line.pop_back();

		auto wstr = utf8To16(line);
		auto fd = split(wstr, u'\t');
		if (fd.size() < 2) continue;

		TestResult tr;
		tr.q = fd[0].to_string();
		// Every field after the query contributes its space-separated tokens, in order.
		for (size_t i = 1; i < fd.size(); ++i)
		{
			for (auto s : split(fd[i], u' ')) tr.a.emplace_back(parseWordPOS(s.to_string()));
		}
		testsets.emplace_back(std::move(tr));
	}
}
/// Analyzes the query of every loaded test set and stores the first returned candidate's
/// tokens in the test set's `r` field.
///
/// When the analyzer returns no candidate at all, `r` is left empty instead of indexing
/// past the end of the candidate list.
void Evaluator::run()
{
	for (auto& tr : testsets)
	{
		auto cands = kw->analyze(tr.q, topN, matchOption);
		if (cands.empty())
		{
			// Nothing to take: cands[0] would be undefined behavior here.
			tr.r.clear();
			continue;
		}
		// `cands` is a local temporary result, so its token list can be moved rather than copied.
		tr.r = std::move(cands[0].first);
	}
}
// Scores the stored analysis result (`r`) of every test set against its expected tokens (`a`).
//
// `errors` is cleared and then receives a copy of every test set whose result differs from
// its answer. For those test sets `score`, `dr` and `da` are filled in as well.
//
// Returns a Score with:
//   micro      - matched entries / total entries, pooled over all test sets
//   macro      - mean of the per-test-set scores (a fully matching test set counts as 1)
//   totalCount - number of test sets
//
// NOTE(review): when there are no test sets both ratios are computed with a zero denominator.
// NOTE(review): `dr` / `da` are appended to without being cleared here, so calling this
// function more than once would accumulate entries unless they are reset elsewhere.
Evaluator::Score Evaluator::evaluate()
{
errors.clear();
size_t totalCount = 0, microCorrect = 0, microCount = 0;
double totalScore = 0;
for (auto& tr : testsets)
{
if (tr.a != tr.r)
{
// Diff the result against the answer using a relaxed token comparison:
//  - tags must be equal after clearIrregular();
//  - tokens tagged POSTag::jko match regardless of their form;
//  - a few form pairs are accepted as interchangeable (in either order);
//  - otherwise the forms must be identical.
// NOTE(review): several string literals below appear empty in this copy of the file -
// presumably Hangul text that was lost; restore them from the upstream source.
auto diff = lcs::getDiff(tr.r.begin(), tr.r.end(), tr.a.begin(), tr.a.end(), [](const TokenInfo& a, const TokenInfo& b)
{
if (clearIrregular(a.tag) != clearIrregular(b.tag)) return false;
if (a.tag == POSTag::jko) return true;
if (a.str == u"" && u"" == b.str) return true;
if (b.str == u"" && u"" == a.str) return true;
if (a.str == u"" && u"" == b.str) return true;
if (b.str == u"" && u"" == a.str) return true;
if (a.str == u"" && u"" == b.str) return true;
if (b.str == u"" && u"" == a.str) return true;
if (a.str == u"그것" && u"그거" == b.str) return true;
if (b.str == u"그것" && u"그거" == a.str) return true;
if (a.str == u"" && u"" == b.str) return true;
if (b.str == u"" && u"" == a.str) return true;
return a.str == b.str;
});
// Sort the diff entries by the sign of `first`: negative entries are collected in `dr`,
// positive ones in `da`, and zero entries are counted as common.
// (Presumably negative = only in the result, positive = only in the answer - confirm in LCS.hpp.)
size_t common = 0;
for (auto&& d : diff)
{
if (d.first < 0) tr.dr.emplace_back(d.second);
else if (d.first > 0) tr.da.emplace_back(d.second);
else common++;
}
// Per-test-set score: share of diff entries that are common to both sequences.
tr.score = common / (double)diff.size();
totalScore += tr.score;
microCorrect += common;
microCount += diff.size();
// Stores a copy of the test set, including the diff details filled in above.
errors.emplace_back(tr);
}
else
{
// Exact match: full score, and every token counts as correct.
totalScore += 1;
microCorrect += tr.r.size();
microCount += tr.r.size();
}
totalCount++;
}
Score ret;
ret.micro = microCorrect / (double)microCount;
ret.macro = totalScore / totalCount;
ret.totalCount = totalCount;
return ret;
}
/// Writes a token as "<form>/<tag>": the form converted with utf16To8(), a slash, and the
/// tag name produced by kiwi::tagToString(). Returns the stream to allow chaining.
ostream& operator<<(ostream& o, const kiwi::TokenInfo& t)
{
	o << utf16To8(t.str);
	o << "/";
	o << kiwi::tagToString(t.tag);
	return o;
}
/// Writes this test result to `out` as a four-line record:
///   1. the query (converted with utf16To8) and the score, separated by a tab
///   2. every token in `da`, each followed by a tab
///   3. every token in `dr`, each followed by a tab
///   4. an empty line
/// Each line is terminated with `endl`, so the stream is flushed after every line.
void Evaluator::TestResult::writeResult(ostream& out) const
{
	out << utf16To8(q) << '\t' << score << endl;

	for (const auto& token : da) out << token << '\t';
	out << endl;

	for (const auto& token : dr) out << token << '\t';
	out << endl;

	// Blank separator line after the record.
	out << endl;
}