Add StreamProvider support to C API, Java, and WASM bindings

Co-authored-by: bab2min <19266222+bab2min@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot] 2025-09-11 17:03:41 +00:00 committed by bab2min
commit 0c42704de1
6 changed files with 428 additions and 0 deletions

View file

@ -3,6 +3,7 @@
#include <kiwi/Kiwi.h>
#include <kiwi/Joiner.h>
#include <sstream>
struct Sentence
{
@ -256,6 +257,13 @@ namespace jni
struct ValueBuilder<kiwi::PretokenizedSpan> : public ValueBuilder<decltype(gClsPretokenizedSpan)>
{
};
// Forward declaration for StreamProvider interface
template<>
struct JClassName<jobject>
{
static constexpr auto value = std::string_view{ "kr/pe/bab2min/KiwiBuilder$StreamProvider" };
};
}
class JKiwi;
@ -523,11 +531,134 @@ public:
class JKiwiBuilder : public kiwi::KiwiBuilder, jni::JObject<JKiwiBuilder>
{
private:
JavaVM* jvm = nullptr;
jobject streamProviderGlobalRef = nullptr;
public:
static constexpr std::string_view className = "kr/pe/bab2min/KiwiBuilder";
using kiwi::KiwiBuilder::KiwiBuilder;
// Custom constructor for StreamProvider
JKiwiBuilder(jobject streamProvider, size_t numThreads, kiwi::BuildOption options, kiwi::ModelType modelType)
: KiwiBuilder(createStreamProviderWrapper(streamProvider), numThreads, options, modelType)
{
}
private:
kiwi::KiwiBuilder::StreamProvider createStreamProviderWrapper(jobject streamProvider)
{
JNIEnv* env = getCurrentEnv();
jvm = getJVM();
streamProviderGlobalRef = env->NewGlobalRef(streamProvider);
return [this](const std::string& filename) -> std::unique_ptr<std::istream>
{
JNIEnv* env = nullptr;
bool shouldDetach = false;
// Get JNIEnv for current thread
jint getEnvResult = jvm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_8);
if (getEnvResult == JNI_EDETACHED)
{
if (jvm->AttachCurrentThread(reinterpret_cast<void**>(&env), nullptr) != JNI_OK)
{
return nullptr;
}
shouldDetach = true;
}
else if (getEnvResult != JNI_OK)
{
return nullptr;
}
try
{
// Get StreamProvider.provide method
jclass streamProviderClass = env->FindClass("kr/pe/bab2min/KiwiBuilder$StreamProvider");
jmethodID provideMethod = env->GetMethodID(streamProviderClass, "provide", "(Ljava/lang/String;)Ljava/io/InputStream;");
// Convert filename to Java string
jstring jFilename = env->NewStringUTF(filename.c_str());
// Call provide method
jobject inputStream = env->CallObjectMethod(streamProviderGlobalRef, provideMethod, jFilename);
if (!inputStream || env->ExceptionCheck())
{
if (env->ExceptionCheck()) env->ExceptionClear();
if (shouldDetach) jvm->DetachCurrentThread();
return nullptr;
}
// Read the InputStream into a byte array
jclass inputStreamClass = env->FindClass("java/io/InputStream");
jmethodID availableMethod = env->GetMethodID(inputStreamClass, "available", "()I");
jmethodID readMethod = env->GetMethodID(inputStreamClass, "read", "([B)I");
jmethodID closeMethod = env->GetMethodID(inputStreamClass, "close", "()V");
jint available = env->CallIntMethod(inputStream, availableMethod);
if (available <= 0) available = 1024 * 1024; // Default to 1MB if available() returns 0
jbyteArray byteArray = env->NewByteArray(available);
std::vector<char> buffer;
int totalRead = 0;
while (true)
{
jint bytesRead = env->CallIntMethod(inputStream, readMethod, byteArray);
if (bytesRead <= 0) break;
jbyte* bytes = env->GetByteArrayElements(byteArray, nullptr);
buffer.insert(buffer.end(), reinterpret_cast<char*>(bytes), reinterpret_cast<char*>(bytes + bytesRead));
env->ReleaseByteArrayElements(byteArray, bytes, JNI_ABORT);
totalRead += bytesRead;
}
env->CallVoidMethod(inputStream, closeMethod);
if (shouldDetach) jvm->DetachCurrentThread();
// Create string stream from buffer
std::string data(buffer.begin(), buffer.end());
return std::make_unique<std::istringstream>(std::move(data));
}
catch (...)
{
if (shouldDetach) jvm->DetachCurrentThread();
return nullptr;
}
};
}
JNIEnv* getCurrentEnv()
{
JNIEnv* env = nullptr;
JavaVM* vm = getJVM();
vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_8);
return env;
}
JavaVM* getJVM()
{
// This should be set by the JNI framework - we'll access it via the module
JavaVM* vm = nullptr;
jsize vmCount;
JNI_GetCreatedJavaVMs(&vm, 1, &vmCount);
return vm;
}
public:
~JKiwiBuilder()
{
if (streamProviderGlobalRef)
{
JNIEnv* env = getCurrentEnv();
if (env) env->DeleteGlobalRef(streamProviderGlobalRef);
}
}
bool addWord(const std::u16string& form, kiwi::POSTag tag, float score)
{
return KiwiBuilder::addWord(form, tag, score).second;
@ -581,6 +712,7 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void* reserved)
jni::define<JKiwiBuilder>()
.template ctor<std::string, size_t, kiwi::BuildOption, kiwi::ModelType>()
.template ctor<jobject, size_t, kiwi::BuildOption, kiwi::ModelType>("ctorStream")
.template method<&JKiwiBuilder::addWord>("addWord")
.template method<&JKiwiBuilder::addWord2>("addWord")
.template method<&JKiwiBuilder::addPreAnalyzedWord>("addPreAnalyzedWord")