📄 svm_jni.c
字号:
// vim:fdm=marker:foldmarker={%{,}%}:# include <jni.h># include "svm_jni.h"JavaParamIDs* GetJParamIDs(JNIEnv * env, jobjectArray* tdata) { JavaParamIDs *ids = my_malloc(sizeof(struct javaparamids)); // Finde und Bestimme Klassentyp von SVMLightModel ids->SVMLightModelCls = (*env)->FindClass(env,"jnisvmlight/SVMLightModel"); if (ids->SVMLightModelCls == 0) { perror("Class 'SVMLightModel' can't be found!: perror()"); exit(1); } // Bestimme IDs der Membervariablen aus der Klasse 'SVMLightModel' ids->ID_string_format = (*env)->GetFieldID(env, ids->SVMLightModelCls, "m_format", "Ljava/lang/String;"); ids->ID_long_kType = (*env)->GetFieldID(env, ids->SVMLightModelCls, "m_kType", "J"); ids->ID_long_dParam = (*env)->GetFieldID(env, ids->SVMLightModelCls, "m_dParam", "J"); ids->ID_double_gParam = (*env)->GetFieldID(env, ids->SVMLightModelCls, "m_gParam", "D"); ids->ID_double_sParam = (*env)->GetFieldID(env, ids->SVMLightModelCls, "m_sParam", "D"); ids->ID_double_rParam = (*env)->GetFieldID(env, ids->SVMLightModelCls, "m_rParam", "D"); ids->ID_string_uParam = (*env)->GetFieldID(env, ids->SVMLightModelCls, "m_uParam", "Ljava/lang/String;"); ids->ID_long_highFeatIdx = (*env)->GetFieldID(env, ids->SVMLightModelCls, "m_highFeatIdx", "J"); ids->ID_long_trainDocs = (*env)->GetFieldID(env, ids->SVMLightModelCls, "m_trainDocs", "J"); ids->ID_long_numSupVecs = (*env)->GetFieldID(env, ids->SVMLightModelCls, "m_numSupVecs", "J"); ids->ID_double_threshold = (*env)->GetFieldID(env, ids->SVMLightModelCls, "m_threshold", "D"); ids->ID_doubleArray_linWeights = (*env)->GetFieldID(env, ids->SVMLightModelCls, "m_linWeights", "[D"); ids->ID_labeledFeatureVectorArray_docs = (*env)->GetFieldID(env, ids->SVMLightModelCls, "m_docs", "[Ljnisvmlight/LabeledFeatureVector;"); if ((ids->ID_string_format && ids->ID_long_kType && ids->ID_long_dParam && ids->ID_double_sParam && ids->ID_double_rParam && ids->ID_string_uParam && ids->ID_long_highFeatIdx && ids->ID_long_trainDocs && ids->ID_long_numSupVecs && ids->ID_double_threshold && ids->ID_labeledFeatureVectorArray_docs) == 0) { perror("Can't access JFieldIDs: perror()"); exit(1); } // Bestimme ID des Konstruktors der Klasse SVMLightModel ids->ConstructorID_SVMLightModelCls = (*env)->GetMethodID( env, ids->SVMLightModelCls, "<init>", "(Ljava/lang/String;JJDDDLjava/lang/String;JJJD[Ljnisvmlight/LabeledFeatureVector;)V" ); if ( ids->ConstructorID_SVMLightModelCls == 0) { perror("Can't determine the constructor-method of SVMLightModel: perror()"); exit(1); } // Bestimmen der Groesse des uebergebenen Arrays (mit Trainingsdokumenten) tdata ids->tDataSize = (*env)->GetArrayLength(env, *tdata); if (ids->tDataSize<1) { perror("\nArray is containing no training documents!\n"); } // Abgreifen des erstes Trainingsdokuments jobject traindoc = (*env)->GetObjectArrayElement(env, *tdata, 0); if (traindoc == NULL) { perror("\ntraining document is null!\n"); } // Klassentyp des Trainingsdokuments bestimmen ids->tDataCls = (*env)->GetObjectClass(env, traindoc); if (ids->tDataCls == 0) { perror("Can't determine the class of training documents: perror()"); exit(1); } // Die IDs der Membervaribalen aus der Klasse des Trainingsdokuments bestimmen ids->ID_double_label = (*env)->GetFieldID(env, ids->tDataCls, "m_label", "D"); ids->ID_double_factor = (*env)->GetFieldID(env, ids->tDataCls, "m_factor", "D"); ids->ID_intArray_dimensions = (*env)->GetFieldID(env, ids->tDataCls, "m_dims", "[I"); ids->ID_doubleArray_values = (*env)->GetFieldID(env, ids->tDataCls, "m_vals", "[D"); // ids->MemVarID_size = (*env)->GetFieldID(env, ids->tDataCls, "m_size", "I"); if (((ids->ID_double_label) && (ids->ID_intArray_dimensions) && (ids->ID_doubleArray_values)) == 0) { perror("Can't determine jfieldIDs (training documents): perror()"); exit(1); } // Die ID des Konstruktors fuer die Klasse eines Trainingsdokuments bestimmen ids->ConstructorID_tDataCls = (*env)->GetMethodID(env, ids->tDataCls, "<init>", "()V"); if ( ids->ConstructorID_tDataCls == 0) { perror("Can't determine the constructor-method of a training document: perror()"); exit(1); } return ids;}JTrainParams* GetJTrainParamIDs(JNIEnv * env, jobject* tparam) { JTrainParams *tids = my_malloc(sizeof(struct jtrainparams)); tids->env=env; jclass tparamCls = (*env)->FindClass(env,"jnisvmlight/TrainingParameters"); if (tparamCls == 0) { perror("Can't determine the class of 'TrainingParameters': perror()"); exit(1); } tids->ID_LearnParam_lp = (*env)->GetFieldID(env, tparamCls, "m_lp", "Ljnisvmlight/LearnParam;"); tids->ID_KernelParam_kp = (*env)->GetFieldID(env, tparamCls, "m_kp", "Ljnisvmlight/KernelParam;"); if ((tids->ID_LearnParam_lp && tids->ID_KernelParam_kp) == 0) { perror("Can't find member variable 'm_lp' or 'm_kp': perror()"); exit(1); } tids->lp = (*env)->GetObjectField(env, *tparam, tids->ID_LearnParam_lp); tids->kp = (*env)->GetObjectField(env, *tparam, tids->ID_KernelParam_kp); if (tids->lp == NULL || tids->kp == NULL) { perror("Can't access 'm_lp' or 'm_kp': perror()"); exit(1); } jclass lpCls = (*env)->GetObjectClass(env,tids->lp); jclass kpCls = (*env)->GetObjectClass(env,tids->kp); if ((lpCls && kpCls) == 0) { perror("Can't determine the class of 'm_lp' or 'm_kp': perror()"); exit(1); } tids->ID_int_verbosity = (*env)->GetFieldID(env, lpCls, "verbosity", "I"); tids->ID_long_type = (*env)->GetFieldID(env, lpCls, "type", "J"); tids->ID_double_svm_c = (*env)->GetFieldID(env, lpCls, "svm_c", "D"); tids->ID_double_eps = (*env)->GetFieldID(env, lpCls, "eps", "D"); tids->ID_double_svm_costratio = (*env)->GetFieldID(env, lpCls, "svm_costratio", "D"); tids->ID_double_transduction_posratio = (*env)->GetFieldID(env, lpCls, "transduction_posratio", "D"); tids->ID_long_biased_hyperplane = (*env)->GetFieldID(env, lpCls, "biased_hyperplane", "J"); tids->ID_long_sharedslack = (*env)->GetFieldID(env, lpCls, "sharedslack", "J"); tids->ID_long_svm_maxqpsize = (*env)->GetFieldID(env, lpCls, "svm_maxqpsize", "J"); tids->ID_long_svm_newvarsinqp = (*env)->GetFieldID(env, lpCls, "svm_newvarsinqp", "J"); tids->ID_long_kernel_cache_size = (*env)->GetFieldID(env, lpCls, "kernel_cache_size", "J"); tids->ID_double_epsilon_crit = (*env)->GetFieldID(env, lpCls, "epsilon_crit", "D"); tids->ID_double_epsilon_shrink = (*env)->GetFieldID(env, lpCls, "epsilon_shrink", "D"); tids->ID_long_svm_iter_to_shrink = (*env)->GetFieldID(env, lpCls, "svm_iter_to_shrink", "J"); tids->ID_long_maxiter = (*env)->GetFieldID(env, lpCls, "maxiter", "J"); tids->ID_long_remove_inconsistent = (*env)->GetFieldID(env, lpCls, "remove_inconsistent", "J"); tids->ID_long_skip_final_opt_check = (*env)->GetFieldID(env, lpCls, "skip_final_opt_check", "J"); tids->ID_long_compute_loo = (*env)->GetFieldID(env, lpCls, "compute_loo", "J"); tids->ID_double_rho = (*env)->GetFieldID(env, lpCls, "rho", "D"); tids->ID_long_xa_depth = (*env)->GetFieldID(env, lpCls, "xa_depth", "J"); tids->ID_string_predfile = (*env)->GetFieldID(env, lpCls, "predfile", "Ljava/lang/String;"); tids->ID_string_alphafile = (*env)->GetFieldID(env, lpCls, "alphafile", "Ljava/lang/String;"); tids->ID_double_epsilon_const = (*env)->GetFieldID(env, lpCls, "epsilon_const", "D"); tids->ID_double_epsilon_a = (*env)->GetFieldID(env, lpCls, "epsilon_a", "D"); tids->ID_double_opt_precision = (*env)->GetFieldID(env, lpCls, "opt_precision", "D"); tids->ID_long_svm_c_steps = (*env)->GetFieldID(env, lpCls, "svm_c_steps", "J"); tids->ID_double_svm_c_factor = (*env)->GetFieldID(env, lpCls, "svm_c_factor", "D"); tids->ID_double_svm_costratio_unlab = (*env)->GetFieldID(env, lpCls, "svm_costratio_unlab", "D"); tids->ID_double_svm_unlabbound = (*env)->GetFieldID(env, lpCls, "svm_unlabbound", "D"); tids->ID_double_svm_cost = (*env)->GetFieldID(env, lpCls, "svm_cost", "D"); tids->ID_long_totwords = (*env)->GetFieldID(env, lpCls, "totwords", "J"); if ((tids->ID_int_verbosity && tids->ID_long_type && tids->ID_double_svm_c && tids->ID_double_eps && tids->ID_double_svm_costratio && tids->ID_double_transduction_posratio && tids->ID_long_biased_hyperplane && tids->ID_long_sharedslack && tids->ID_long_svm_maxqpsize && tids->ID_long_svm_newvarsinqp && tids->ID_long_kernel_cache_size && tids->ID_double_epsilon_crit && tids->ID_double_epsilon_shrink && tids->ID_long_svm_iter_to_shrink && tids->ID_long_maxiter && tids->ID_long_remove_inconsistent && tids->ID_long_skip_final_opt_check && tids->ID_long_compute_loo && tids->ID_double_rho && tids->ID_long_xa_depth && tids->ID_string_predfile && tids->ID_string_alphafile && tids->ID_double_epsilon_const && tids->ID_double_epsilon_a && tids->ID_double_opt_precision && tids->ID_long_svm_c_steps && tids->ID_double_svm_c_factor && tids->ID_double_svm_costratio_unlab && tids->ID_double_svm_unlabbound && tids->ID_double_svm_cost && tids->ID_long_totwords) == 0) { perror("Can't determine the jfieldIDs of class 'LearnParam': perror()"); exit(1); } tids->ID_long_kernel_type = (*env)->GetFieldID(env, kpCls, "kernel_type", "J"); tids->ID_long_poly_degree = (*env)->GetFieldID(env, kpCls, "poly_degree", "J"); tids->ID_double_rbf_gamma = (*env)->GetFieldID(env, kpCls, "rbf_gamma", "D"); tids->ID_double_coef_lin = (*env)->GetFieldID(env, kpCls, "coef_lin", "D"); tids->ID_double_coef_const = (*env)->GetFieldID(env, kpCls, "coef_const", "D"); tids->ID_string_custom = (*env)->GetFieldID(env, kpCls, "custom", "Ljava/lang/String;"); if ((tids->ID_long_kernel_type && tids->ID_long_poly_degree && tids->ID_double_rbf_gamma && tids->ID_double_coef_lin && tids->ID_double_coef_const && tids->ID_string_custom) == 0) { perror("Can't determine the jfieldIDs of class 'KernelParam': perror()"); exit(1); } jfieldID argcID = (*env)->GetFieldID(env, lpCls, "argc", "I"); jfieldID argvID = (*env)->GetFieldID(env, lpCls, "argv", "[Ljava/lang/String;"); if ((argcID && argvID) == 0) { perror("Can't find jfieldIDs of 'argc/argv'"); exit(1); } tids->argc = (*env)->GetIntField(env,tids->lp,argcID); jobjectArray sfield = (*env)->GetObjectField(env,tids->lp,argvID); if (tids->argc > 0) { tids->argv = (char**) my_malloc(sizeof(char*) * tids->argc); int j; for (j=0;j<tids->argc;j++) { jstring jstr = (*env)->GetObjectArrayElement(env, sfield, j); const char* str = (*env)->GetStringUTFChars(env, jstr, 0 ); (tids->argv)[j] = (char*) my_malloc(sizeof(char) * strlen(str)+1); strcpy((tids->argv)[j],str); (*env)->ReleaseStringUTFChars(env,jstr,str); } } return tids;}void SVMparmInit(KERNEL_CACHE* kernel_cache,LEARN_PARM* learn_parm,KERNEL_PARM* kernel_parm, MODEL* model, JTrainParams* tparm) { char type[100] = " "; jstring test; const char *str; JNIEnv* env = tparm->env; int argc = tparm->argc; char **argv = tparm->argv; verbosity = (*env)->GetIntField(env,tparm->lp,tparm->ID_int_verbosity); learn_parm->type = (*env)->GetLongField(env,tparm->lp,tparm->ID_long_type); /* learn_parm->svm_c=0.0; */ learn_parm->svm_c = (*env)->GetDoubleField(env,tparm->lp,tparm->ID_double_svm_c); /* learn_parm->eps=0.1; */ learn_parm->eps = (*env)->GetDoubleField(env,tparm->lp,tparm->ID_double_eps); /* learn_parm->svm_costratio=1.0; */ learn_parm->svm_costratio = (*env)->GetDoubleField(env,tparm->lp,tparm->ID_double_svm_costratio); /* learn_parm->transduction_posratio=-1.0; */ learn_parm->transduction_posratio = (*env)->GetDoubleField(env,tparm->lp,tparm->ID_double_transduction_posratio); /* learn_parm->biased_hyperplane=1; */ learn_parm->biased_hyperplane = (*env)->GetLongField(env,tparm->lp,tparm->ID_long_biased_hyperplane); /* learn_parm->sharedslack=0; */ learn_parm->sharedslack = (*env)->GetLongField(env,tparm->lp,tparm->ID_long_sharedslack); /* learn_parm->svm_maxqpsize=10; */ learn_parm->svm_maxqpsize = (*env)->GetLongField(env,tparm->lp,tparm->ID_long_svm_maxqpsize); /* learn_parm->svm_newvarsinqp=0; */ learn_parm->svm_newvarsinqp = (*env)->GetLongField(env,tparm->lp,tparm->ID_long_svm_newvarsinqp); /* learn_parm->kernel_cache_size=40; */ learn_parm->kernel_cache_size = (*env)->GetLongField(env,tparm->lp,tparm->ID_long_kernel_cache_size); /* learn_parm->epsilon_crit=0.001; */ learn_parm->epsilon_crit = (*env)->GetDoubleField(env,tparm->lp,tparm->ID_double_epsilon_crit); learn_parm->epsilon_shrink = (*env)->GetDoubleField(env,tparm->lp,tparm->ID_double_epsilon_shrink); /* learn_parm->svm_iter_to_shrink=-9999; */ learn_parm->svm_iter_to_shrink = (*env)->GetLongField(env,tparm->lp,tparm->ID_long_svm_iter_to_shrink); /* learn_parm->maxiter=100000; */ learn_parm->maxiter = (*env)->GetLongField(env,tparm->lp,tparm->ID_long_maxiter); /* learn_parm->remove_inconsistent=0; */ learn_parm->remove_inconsistent = (*env)->GetLongField(env,tparm->lp,tparm->ID_long_remove_inconsistent); /* learn_parm->skip_final_opt_check=0; */ learn_parm->skip_final_opt_check = (*env)->GetLongField(env,tparm->lp,tparm->ID_long_skip_final_opt_check); /* learn_parm->compute_loo=0; */ learn_parm->compute_loo = (*env)->GetLongField(env,tparm->lp,tparm->ID_long_compute_loo); /* learn_parm->rho=1.0; */ learn_parm->rho = (*env)->GetDoubleField(env,tparm->lp,tparm->ID_double_rho); /* learn_parm->xa_depth=0; */ learn_parm->xa_depth = (*env)->GetLongField(env,tparm->lp,tparm->ID_long_xa_depth); /* strcpy (learn_parm->predfile, "trans_predictions"); */ test = (*env)->GetObjectField(env, tparm->lp, tparm->ID_string_predfile); str = (*env)->GetStringUTFChars(env, test, 0 ); strcpy (learn_parm->predfile, str); (*env)->ReleaseStringUTFChars(env,test,str); /* strcpy (learn_parm->alphafile, ""); */ test = (*env)->GetObjectField(env, tparm->lp, tparm->ID_string_alphafile); str = (*env)->GetStringUTFChars(env, test, 0 ); strcpy (learn_parm->alphafile, str); (*env)->ReleaseStringUTFChars(env,test,str); learn_parm->epsilon_const = (*env)->GetDoubleField(env,tparm->lp,tparm->ID_double_epsilon_const); /* learn_parm->epsilon_a=1E-15; */ learn_parm->epsilon_a = (*env)->GetDoubleField(env,tparm->lp,tparm->ID_double_epsilon_a); learn_parm->opt_precision = (*env)->GetDoubleField(env,tparm->lp,tparm->ID_double_opt_precision); learn_parm->svm_c_steps = (*env)->GetLongField(env,tparm->lp,tparm->ID_long_svm_c_steps); learn_parm->svm_c_factor = (*env)->GetDoubleField(env,tparm->lp,tparm->ID_double_svm_c_factor); /* learn_parm->svm_costratio_unlab=1.0; */ learn_parm->svm_costratio_unlab = (*env)->GetDoubleField(env,tparm->lp,tparm->ID_double_svm_costratio_unlab); /* learn_parm->svm_unlabbound=1E-5; */ learn_parm->svm_unlabbound = (*env)->GetDoubleField(env,tparm->lp,tparm->ID_double_svm_unlabbound); learn_parm->svm_cost = (double *) my_malloc(sizeof(double)); *(learn_parm->svm_cost) = (double) (*env)->GetDoubleField(env,tparm->lp,tparm->ID_double_svm_cost); learn_parm->totwords = (*env)->GetLongField(env,tparm->lp,tparm->ID_long_svm_c_steps); /* kernel_parm->kernel_type=0; */ kernel_parm->kernel_type = (*env)->GetLongField(env,tparm->kp,tparm->ID_long_kernel_type); /* kernel_parm->poly_degree=3; */ kernel_parm->poly_degree = (*env)->GetLongField(env,tparm->kp,tparm->ID_long_poly_degree); /* kernel_parm->rbf_gamma=1.0; */ kernel_parm->rbf_gamma = (*env)->GetDoubleField(env,tparm->kp,tparm->ID_double_rbf_gamma); /* kernel_parm->coef_lin=1; */ kernel_parm->coef_lin = (*env)->GetDoubleField(env,tparm->kp,tparm->ID_double_coef_lin); /* kernel_parm->coef_const=1; */ kernel_parm->coef_const = (*env)->GetDoubleField(env,tparm->kp,tparm->ID_double_coef_const); /* strcpy(kernel_parm->custom,"empty"); */ test = (*env)->GetObjectField(env, tparm->kp, tparm->ID_string_custom); str = (*env)->GetStringUTFChars(env, test, 0 ); strcpy(kernel_parm->custom, str); (*env)->ReleaseStringUTFChars(env,test,str); if (argc>0) { int i=0; for(i=0;(i<argc) && ((argv[i])[0] == '-');i++) { switch ((argv[i])[1]) { case '?': print_help(); exit(0); case 'z': i++; strcpy(type,argv[i]); break; case 'v': i++; verbosity=atol(argv[i]); break; case 'b': i++; learn_parm->biased_hyperplane=atol(argv[i]); break; case 'i': i++; learn_parm->remove_inconsistent=atol(argv[i]); break; case 'f': i++; learn_parm->skip_final_opt_check=!atol(argv[i]); break; case 'q': i++; learn_parm->svm_maxqpsize=atol(argv[i]); break; case 'n': i++; learn_parm->svm_newvarsinqp=atol(argv[i]); break; case '#': i++; learn_parm->maxiter=atol(argv[i]); break; case 'h': i++; learn_parm->svm_iter_to_shrink=atol(argv[i]); break; case 'm': i++; learn_parm->kernel_cache_size=atol(argv[i]); break; case 'c': i++; learn_parm->svm_c=atof(argv[i]); break; case 'w': i++; learn_parm->eps=atof(argv[i]); break; case 'p': i++; learn_parm->transduction_posratio=atof(argv[i]); break; case 'j': i++; learn_parm->svm_costratio=atof(argv[i]); break; case 'e': i++; learn_parm->epsilon_crit=atof(argv[i]); break; case 'o': i++; learn_parm->rho=atof(argv[i]); break; case 'k': i++; learn_parm->xa_depth=atol(argv[i]); break; case 'x': i++; learn_parm->compute_loo=atol(argv[i]); break; case 't': i++; kernel_parm->kernel_type=atol(argv[i]); break; case 'd': i++; kernel_parm->poly_degree=atol(argv[i]); break; case 'g': i++; kernel_parm->rbf_gamma=atof(argv[i]); break; case 's': i++; kernel_parm->coef_lin=atof(argv[i]); break; case 'r': i++; kernel_parm->coef_const=atof(argv[i]); break; case 'u': i++; strcpy(kernel_parm->custom,argv[i]); break; case 'l': i++; strcpy(learn_parm->predfile,argv[i]); break; case 'a': i++; strcpy(learn_parm->alphafile,argv[i]); break; case 'y': i++; printf("Option \"-y\" is not supported in this Version of the JNI-SVMLight-interface!\n"); fflush(stdout); break; default: printf("\nUnrecognized option %s!\n\n",argv[i]); print_help(); exit(0); } } if(strcmp(type,"c")==0) { learn_parm->type=CLASSIFICATION; } else if(strcmp(type,"r")==0) { learn_parm->type=REGRESSION; } else if(strcmp(type,"p")==0) { learn_parm->type=RANKING; } else if(strcmp(type,"o")==0) { learn_parm->type=OPTIMIZATION; } else if(strcmp(type,"s")==0) { learn_parm->type=OPTIMIZATION; learn_parm->sharedslack=1; } else if (strcmp(type," ") != 0 || ((learn_parm->type & (CLASSIFICATION | REGRESSION | RANKING | OPTIMIZATION))==0)) { printf("\n\nUnknown type '%s': Valid types are 'c' (classification), 'r' regession, and 'p' preference ranking.\n",type); fflush(stdout); printf("\n\nPress Return for help\n\n"); fflush(stdout); wait_any_key(); print_help(); exit(0); } } if(learn_parm->svm_iter_to_shrink == -9999) { if(kernel_parm->kernel_type == LINEAR) learn_parm->svm_iter_to_shrink=2; else
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -