00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043 #include "ckd_alloc.h"
00044 #include "ngram_model_arpa.h"
00045 #include "err.h"
00046 #include "pio.h"
00047 #include "listelem_alloc.h"
00048 #include "strfuncs.h"
00049
00050 #include <string.h>
00051 #include <limits.h>
00052
/* Method table for this model type; forward-declared here because it is
 * passed to ngram_model_init() in ngram_model_arpa_read() but defined at
 * the end of the file, after the functions it points to. */
static ngram_funcs_t ngram_model_arpa_funcs;

/* Absolute trigram base of the segment containing bigram index b.
 * Bigrams are grouped into segments of (1 << LOG_BG_SEG_SZ) entries that
 * share one entry of tseg_base[]. */
#define TSEG_BASE(m,b) ((m)->lm3g.tseg_base[(b)>>LOG_BG_SEG_SZ])
/* Index of the first bigram whose history word is unigram u. */
#define FIRST_BG(m,u) ((m)->lm3g.unigrams[u].bigrams)
/* Absolute index of the first trigram following bigram b: the segment
 * base plus the bigram's stored offset from it. */
#define FIRST_TG(m,b) (TSEG_BASE((m),(b))+((m)->lm3g.bigrams[b].trigrams))
00058
00059
00060
00061
00062
00063 static void
00064 init_sorted_list(sorted_list_t * l)
00065 {
00066
00067 l->list = ckd_calloc(MAX_SORTED_ENTRIES,
00068 sizeof(sorted_entry_t));
00069 l->list[0].val.l = INT_MIN;
00070 l->list[0].lower = 0;
00071 l->list[0].higher = 0;
00072 l->free = 1;
00073 }
00074
00075 static void
00076 free_sorted_list(sorted_list_t * l)
00077 {
00078 free(l->list);
00079 }
00080
00081 static lmprob_t *
00082 vals_in_sorted_list(sorted_list_t * l)
00083 {
00084 lmprob_t *vals;
00085 int32 i;
00086
00087 vals = ckd_calloc(l->free, sizeof(lmprob_t));
00088 for (i = 0; i < l->free; i++)
00089 vals[i] = l->list[i].val;
00090 return (vals);
00091 }
00092
00093 static int32
00094 sorted_id(sorted_list_t * l, int32 *val)
00095 {
00096 int32 i = 0;
00097
00098 for (;;) {
00099 if (*val == l->list[i].val.l)
00100 return (i);
00101 if (*val < l->list[i].val.l) {
00102 if (l->list[i].lower == 0) {
00103 if (l->free >= MAX_SORTED_ENTRIES) {
00104
00105 E_WARN("sorted list overflow (%d => %d)\n",
00106 *val, l->list[i].val.l);
00107 return i;
00108 }
00109
00110 l->list[i].lower = l->free;
00111 (l->free)++;
00112 i = l->list[i].lower;
00113 l->list[i].val.l = *val;
00114 return (i);
00115 }
00116 else
00117 i = l->list[i].lower;
00118 }
00119 else {
00120 if (l->list[i].higher == 0) {
00121 if (l->free >= MAX_SORTED_ENTRIES) {
00122
00123 E_WARN("sorted list overflow (%d => %d)\n",
00124 *val, l->list[i].val);
00125 return i;
00126 }
00127
00128 l->list[i].higher = l->free;
00129 (l->free)++;
00130 i = l->list[i].higher;
00131 l->list[i].val.l = *val;
00132 return (i);
00133 }
00134 else
00135 i = l->list[i].higher;
00136 }
00137 }
00138 }
00139
00140
00141
00142
00143 static int
00144 ReadNgramCounts(FILE * fp, int32 * n_ug, int32 * n_bg, int32 * n_tg)
00145 {
00146 char string[256];
00147 int32 ngram, ngram_cnt;
00148
00149
00150 do
00151 fgets(string, sizeof(string), fp);
00152 while ((strcmp(string, "\\data\\\n") != 0) && (!feof(fp)));
00153
00154 if (strcmp(string, "\\data\\\n") != 0) {
00155 E_ERROR("No \\data\\ mark in LM file\n");
00156 return -1;
00157 }
00158
00159 *n_ug = *n_bg = *n_tg = 0;
00160 while (fgets(string, sizeof(string), fp) != NULL) {
00161 if (sscanf(string, "ngram %d=%d", &ngram, &ngram_cnt) != 2)
00162 break;
00163 switch (ngram) {
00164 case 1:
00165 *n_ug = ngram_cnt;
00166 break;
00167 case 2:
00168 *n_bg = ngram_cnt;
00169 break;
00170 case 3:
00171 *n_tg = ngram_cnt;
00172 break;
00173 default:
00174 E_ERROR("Unknown ngram (%d)\n", ngram);
00175 return -1;
00176 }
00177 }
00178
00179
00180 while ((strcmp(string, "\\1-grams:\n") != 0) && (!feof(fp)))
00181 fgets(string, sizeof(string), fp);
00182
00183
00184 if ((*n_ug <= 0) || (*n_bg <= 0) || (*n_tg < 0)) {
00185 E_ERROR("Bad or missing ngram count\n");
00186 return -1;
00187 }
00188 return 0;
00189 }
00190
00191
00192
00193
00194
00195
00196 static int
00197 ReadUnigrams(FILE * fp, ngram_model_arpa_t * model)
00198 {
00199 ngram_model_t *base = &model->base;
00200 char string[256];
00201 int32 wcnt;
00202 float p1;
00203
00204 E_INFO("Reading unigrams\n");
00205
00206 wcnt = 0;
00207 while ((fgets(string, sizeof(string), fp) != NULL) &&
00208 (strcmp(string, "\\2-grams:\n") != 0)) {
00209 char *wptr[3], *name;
00210 float32 bo_wt = 0.0f;
00211 int n;
00212
00213 if ((n = str2words(string, wptr, 3)) < 2) {
00214 if (string[0] != '\n')
00215 E_WARN("Format error; unigram ignored: %s\n", string);
00216 continue;
00217 }
00218 else {
00219 p1 = (float)atof_c(wptr[0]);
00220 name = wptr[1];
00221 if (n == 3)
00222 bo_wt = (float)atof_c(wptr[2]);
00223 }
00224
00225 if (wcnt >= base->n_counts[0]) {
00226 E_ERROR("Too many unigrams\n");
00227 return -1;
00228 }
00229
00230
00231 base->word_str[wcnt] = ckd_salloc(name);
00232 if ((hash_table_enter(base->wid, base->word_str[wcnt], (void *)(long)wcnt))
00233 != (void *)(long)wcnt) {
00234 E_WARN("Duplicate word in dictionary: %s\n", base->word_str[wcnt]);
00235 }
00236 model->lm3g.unigrams[wcnt].prob1.l = logmath_log10_to_log(base->lmath, p1);
00237 model->lm3g.unigrams[wcnt].bo_wt1.l = logmath_log10_to_log(base->lmath, bo_wt);
00238 wcnt++;
00239 }
00240
00241 if (base->n_counts[0] != wcnt) {
00242 E_WARN("lm_t.ucount(%d) != #unigrams read(%d)\n",
00243 base->n_counts[0], wcnt);
00244 base->n_counts[0] = wcnt;
00245 }
00246 return 0;
00247 }
00248
00249
00250
00251
00252 static int
00253 ReadBigrams(FILE * fp, ngram_model_arpa_t * model)
00254 {
00255 ngram_model_t *base = &model->base;
00256 char string[1024];
00257 int32 w1, w2, prev_w1, bgcount;
00258 bigram_t *bgptr;
00259
00260 E_INFO("Reading bigrams\n");
00261
00262 bgcount = 0;
00263 bgptr = model->lm3g.bigrams;
00264 prev_w1 = -1;
00265
00266 while (fgets(string, sizeof(string), fp) != NULL) {
00267 float32 p, bo_wt = 0.0f;
00268 int32 p2, bo_wt2;
00269 char *wptr[4], *word1, *word2;
00270 int n;
00271
00272 wptr[3] = NULL;
00273 if ((n = str2words(string, wptr, 4)) < 3) {
00274 if (string[0] != '\n')
00275 break;
00276 continue;
00277 }
00278 else {
00279 p = (float32)atof_c(wptr[0]);
00280 word1 = wptr[1];
00281 word2 = wptr[2];
00282 if (wptr[3])
00283 bo_wt = (float32)atof_c(wptr[3]);
00284 }
00285
00286 if ((w1 = ngram_wid(base, word1)) == NGRAM_INVALID_WID) {
00287 E_ERROR("Unknown word: %s, skipping bigram (%s %s)\n",
00288 word1, word1, word2);
00289 continue;
00290 }
00291 if ((w2 = ngram_wid(base, word2)) == NGRAM_INVALID_WID) {
00292 E_ERROR("Unknown word: %s, skipping bigram (%s %s)\n",
00293 word2, word1, word2);
00294 continue;
00295 }
00296
00297
00298
00299 p = (float32)((int32)(p * 10000)) / 10000;
00300 bo_wt = (float32)((int32)(bo_wt * 10000)) / 10000;
00301
00302 p2 = logmath_log10_to_log(base->lmath, p);
00303 bo_wt2 = logmath_log10_to_log(base->lmath, bo_wt);
00304
00305 if (bgcount >= base->n_counts[1]) {
00306 E_ERROR("Too many bigrams\n");
00307 return -1;
00308 }
00309
00310 bgptr->wid = w2;
00311 bgptr->prob2 = sorted_id(&model->sorted_prob2, &p2);
00312 if (base->n_counts[2] > 0)
00313 bgptr->bo_wt2 = sorted_id(&model->sorted_bo_wt2, &bo_wt2);
00314
00315 if (w1 != prev_w1) {
00316 if (w1 < prev_w1) {
00317 E_ERROR("Bigrams not in unigram order\n");
00318 return -1;
00319 }
00320
00321 for (prev_w1++; prev_w1 <= w1; prev_w1++)
00322 model->lm3g.unigrams[prev_w1].bigrams = bgcount;
00323 prev_w1 = w1;
00324 }
00325
00326 bgcount++;
00327 bgptr++;
00328
00329 if ((bgcount & 0x0000ffff) == 0) {
00330 E_INFOCONT(".");
00331 }
00332 }
00333 if ((strcmp(string, "\\end\\") != 0)
00334 && (strcmp(string, "\\3-grams:") != 0)) {
00335 E_ERROR("Bad bigram: %s\n", string);
00336 return -1;
00337 }
00338
00339 for (prev_w1++; prev_w1 <= base->n_counts[0]; prev_w1++)
00340 model->lm3g.unigrams[prev_w1].bigrams = bgcount;
00341
00342 return 0;
00343 }
00344
00345
00346
00347
00348 static int
00349 ReadTrigrams(FILE * fp, ngram_model_arpa_t * model)
00350 {
00351 ngram_model_t *base = &model->base;
00352 char string[1024];
00353 int32 i, w1, w2, w3, prev_w1, prev_w2, tgcount, prev_bg, bg, endbg;
00354 int32 seg, prev_seg, prev_seg_lastbg;
00355 trigram_t *tgptr;
00356 bigram_t *bgptr;
00357
00358 E_INFO("Reading trigrams\n");
00359
00360 tgcount = 0;
00361 tgptr = model->lm3g.trigrams;
00362 prev_w1 = -1;
00363 prev_w2 = -1;
00364 prev_bg = -1;
00365 prev_seg = -1;
00366
00367 while (fgets(string, sizeof(string), fp) != NULL) {
00368 float32 p;
00369 int32 p3;
00370 char *wptr[4], *word1, *word2, *word3;
00371
00372 if (str2words(string, wptr, 4) != 4) {
00373 if (string[0] != '\n')
00374 break;
00375 continue;
00376 }
00377 else {
00378 p = (float32)atof_c(wptr[0]);
00379 word1 = wptr[1];
00380 word2 = wptr[2];
00381 word3 = wptr[3];
00382 }
00383
00384 if ((w1 = ngram_wid(base, word1)) == NGRAM_INVALID_WID) {
00385 E_ERROR("Unknown word: %s, skipping trigram (%s %s %s)\n",
00386 word1, word1, word2, word3);
00387 continue;
00388 }
00389 if ((w2 = ngram_wid(base, word2)) == NGRAM_INVALID_WID) {
00390 E_ERROR("Unknown word: %s, skipping trigram (%s %s %s)\n",
00391 word2, word1, word2, word3);
00392 continue;
00393 }
00394 if ((w3 = ngram_wid(base, word3)) == NGRAM_INVALID_WID) {
00395 E_ERROR("Unknown word: %s, skipping trigram (%s %s %s)\n",
00396 word3, word1, word2, word3);
00397 continue;
00398 }
00399
00400
00401
00402 p = (float32)((int32)(p * 10000)) / 10000;
00403 p3 = logmath_log10_to_log(base->lmath, p);
00404
00405 if (tgcount >= base->n_counts[2]) {
00406 E_ERROR("Too many trigrams\n");
00407 return -1;
00408 }
00409
00410 tgptr->wid = w3;
00411 tgptr->prob3 = sorted_id(&model->sorted_prob3, &p3);
00412
00413 if ((w1 != prev_w1) || (w2 != prev_w2)) {
00414
00415 if ((w1 < prev_w1) || ((w1 == prev_w1) && (w2 < prev_w2))) {
00416 E_ERROR("Trigrams not in bigram order\n");
00417 return -1;
00418 }
00419
00420 bg = (w1 !=
00421 prev_w1) ? model->lm3g.unigrams[w1].bigrams : prev_bg + 1;
00422 endbg = model->lm3g.unigrams[w1 + 1].bigrams;
00423 bgptr = model->lm3g.bigrams + bg;
00424 for (; (bg < endbg) && (bgptr->wid != w2); bg++, bgptr++);
00425 if (bg >= endbg) {
00426 E_ERROR("Missing bigram for trigram: %s", string);
00427 return -1;
00428 }
00429
00430
00431 seg = bg >> LOG_BG_SEG_SZ;
00432 for (i = prev_seg + 1; i <= seg; i++)
00433 model->lm3g.tseg_base[i] = tgcount;
00434
00435
00436 if (prev_seg < seg) {
00437 int32 tgoff = 0;
00438
00439 if (prev_seg >= 0) {
00440 tgoff = tgcount - model->lm3g.tseg_base[prev_seg];
00441 if (tgoff > 65535) {
00442 E_ERROR("Offset from tseg_base > 65535\n");
00443 return -1;
00444 }
00445 }
00446
00447 prev_seg_lastbg = ((prev_seg + 1) << LOG_BG_SEG_SZ) - 1;
00448 bgptr = model->lm3g.bigrams + prev_bg;
00449 for (++prev_bg, ++bgptr; prev_bg <= prev_seg_lastbg;
00450 prev_bg++, bgptr++)
00451 bgptr->trigrams = tgoff;
00452
00453 for (; prev_bg <= bg; prev_bg++, bgptr++)
00454 bgptr->trigrams = 0;
00455 }
00456 else {
00457 int32 tgoff;
00458
00459 tgoff = tgcount - model->lm3g.tseg_base[prev_seg];
00460 if (tgoff > 65535) {
00461 E_ERROR("Offset from tseg_base > 65535\n");
00462 return -1;
00463 }
00464
00465 bgptr = model->lm3g.bigrams + prev_bg;
00466 for (++prev_bg, ++bgptr; prev_bg <= bg; prev_bg++, bgptr++)
00467 bgptr->trigrams = tgoff;
00468 }
00469
00470 prev_w1 = w1;
00471 prev_w2 = w2;
00472 prev_bg = bg;
00473 prev_seg = seg;
00474 }
00475
00476 tgcount++;
00477 tgptr++;
00478
00479 if ((tgcount & 0x0000ffff) == 0) {
00480 E_INFOCONT(".");
00481 }
00482 }
00483 if (strcmp(string, "\\end\\") != 0) {
00484 E_ERROR("Bad trigram: %s\n", string);
00485 return -1;
00486 }
00487
00488 for (prev_bg++; prev_bg <= base->n_counts[1]; prev_bg++) {
00489 if ((prev_bg & (BG_SEG_SZ - 1)) == 0)
00490 model->lm3g.tseg_base[prev_bg >> LOG_BG_SEG_SZ] = tgcount;
00491 if ((tgcount - model->lm3g.tseg_base[prev_bg >> LOG_BG_SEG_SZ]) > 65535) {
00492 E_ERROR("Offset from tseg_base > 65535\n");
00493 return -1;
00494 }
00495 model->lm3g.bigrams[prev_bg].trigrams =
00496 tgcount - model->lm3g.tseg_base[prev_bg >> LOG_BG_SEG_SZ];
00497 }
00498 return 0;
00499 }
00500
00501 static unigram_t *
00502 new_unigram_table(int32 n_ug)
00503 {
00504 unigram_t *table;
00505 int32 i;
00506
00507 table = ckd_calloc(n_ug, sizeof(unigram_t));
00508 for (i = 0; i < n_ug; i++) {
00509 table[i].prob1.l = INT_MIN;
00510 table[i].bo_wt1.l = INT_MIN;
00511 }
00512 return table;
00513 }
00514
00515 ngram_model_t *
00516 ngram_model_arpa_read(cmd_ln_t *config,
00517 const char *file_name,
00518 logmath_t *lmath)
00519 {
00520 FILE *fp;
00521 int32 is_pipe;
00522 int32 n_unigram;
00523 int32 n_bigram;
00524 int32 n_trigram;
00525 int32 n;
00526 ngram_model_arpa_t *model;
00527 ngram_model_t *base;
00528
00529 if ((fp = fopen_comp(file_name, "r", &is_pipe)) == NULL) {
00530 E_ERROR("File %s not found\n", file_name);
00531 return NULL;
00532 }
00533
00534
00535 if (ReadNgramCounts(fp, &n_unigram, &n_bigram, &n_trigram) == -1) {
00536 fclose_comp(fp, is_pipe);
00537 return NULL;
00538 }
00539 E_INFO("ngrams 1=%d, 2=%d, 3=%d\n", n_unigram, n_bigram, n_trigram);
00540
00541
00542 model = ckd_calloc(1, sizeof(*model));
00543 base = &model->base;
00544 if (n_trigram > 0)
00545 n = 3;
00546 else if (n_bigram > 0)
00547 n = 2;
00548 else
00549 n = 1;
00550
00551 ngram_model_init(base, &ngram_model_arpa_funcs, lmath, n, n_unigram);
00552 base->n_counts[0] = n_unigram;
00553 base->n_counts[1] = n_bigram;
00554 base->n_counts[2] = n_trigram;
00555 base->writable = TRUE;
00556
00557
00558
00559
00560
00561 model->lm3g.unigrams = new_unigram_table(n_unigram + 1);
00562 model->lm3g.bigrams =
00563 ckd_calloc(n_bigram + 1, sizeof(bigram_t));
00564 if (n_trigram > 0)
00565 model->lm3g.trigrams =
00566 ckd_calloc(n_trigram, sizeof(trigram_t));
00567
00568 if (n_trigram > 0) {
00569 model->lm3g.tseg_base =
00570 ckd_calloc((n_bigram + 1) / BG_SEG_SZ + 1,
00571 sizeof(int32));
00572 }
00573 if (ReadUnigrams(fp, model) == -1) {
00574 fclose_comp(fp, is_pipe);
00575 ngram_model_free(base);
00576 return NULL;
00577 }
00578 E_INFO("%8d = #unigrams created\n", base->n_counts[0]);
00579
00580 init_sorted_list(&model->sorted_prob2);
00581 if (base->n_counts[2] > 0)
00582 init_sorted_list(&model->sorted_bo_wt2);
00583
00584 if (ReadBigrams(fp, model) == -1) {
00585 fclose_comp(fp, is_pipe);
00586 ngram_model_free(base);
00587 return NULL;
00588 }
00589
00590 base->n_counts[1] = FIRST_BG(model, base->n_counts[0]);
00591 model->lm3g.n_prob2 = model->sorted_prob2.free;
00592 model->lm3g.prob2 = vals_in_sorted_list(&model->sorted_prob2);
00593 free_sorted_list(&model->sorted_prob2);
00594 E_INFO("%8d = #bigrams created\n", base->n_counts[1]);
00595 E_INFO("%8d = #prob2 entries\n", model->lm3g.n_prob2);
00596
00597 if (base->n_counts[2] > 0) {
00598
00599 model->lm3g.n_bo_wt2 = model->sorted_bo_wt2.free;
00600 model->lm3g.bo_wt2 = vals_in_sorted_list(&model->sorted_bo_wt2);
00601 free_sorted_list(&model->sorted_bo_wt2);
00602 E_INFO("%8d = #bo_wt2 entries\n", model->lm3g.n_bo_wt2);
00603
00604 init_sorted_list(&model->sorted_prob3);
00605
00606 if (ReadTrigrams(fp, model) == -1) {
00607 fclose_comp(fp, is_pipe);
00608 ngram_model_free(base);
00609 return NULL;
00610 }
00611
00612 base->n_counts[2] = FIRST_TG(model, base->n_counts[1]);
00613 model->lm3g.n_prob3 = model->sorted_prob3.free;
00614 model->lm3g.prob3 = vals_in_sorted_list(&model->sorted_prob3);
00615 E_INFO("%8d = #trigrams created\n", base->n_counts[2]);
00616 E_INFO("%8d = #prob3 entries\n", model->lm3g.n_prob3);
00617
00618 free_sorted_list(&model->sorted_prob3);
00619
00620
00621 model->lm3g.tginfo = ckd_calloc(n_unigram, sizeof(tginfo_t *));
00622 model->lm3g.le = listelem_alloc_init(sizeof(tginfo_t));
00623 }
00624
00625 fclose_comp(fp, is_pipe);
00626 return base;
00627 }
00628
00629 int
00630 ngram_model_arpa_write(ngram_model_t *model,
00631 const char *file_name)
00632 {
00633 return -1;
00634 }
00635
00636 static int
00637 ngram_model_arpa_apply_weights(ngram_model_t *base, float32 lw,
00638 float32 wip, float32 uw)
00639 {
00640 ngram_model_arpa_t *model = (ngram_model_arpa_t *)base;
00641 lm3g_apply_weights(base, &model->lm3g, lw, wip, uw);
00642 return 0;
00643 }
00644
00645
00646 #define BINARY_SEARCH_THRESH 16
00647 static int32
00648 find_bg(bigram_t * bg, int32 n, int32 w)
00649 {
00650 int32 i, b, e;
00651
00652
00653 b = 0;
00654 e = n;
00655 while (e - b > BINARY_SEARCH_THRESH) {
00656 i = (b + e) >> 1;
00657 if ((int32)bg[i].wid < w)
00658 b = i + 1;
00659 else if ((int32)bg[i].wid > w)
00660 e = i;
00661 else
00662 return i;
00663 }
00664
00665
00666 for (i = b; (i < e) && (bg[i].wid != w); i++);
00667 return ((i < e) ? i : -1);
00668 }
00669
00670 static int32
00671 lm3g_bg_score(ngram_model_arpa_t *model, int32 lw1,
00672 int32 lw2, int32 *n_used)
00673 {
00674 int32 i, n, b, score;
00675 bigram_t *bg;
00676
00677 if (lw1 < 0) {
00678 *n_used = 1;
00679 return model->lm3g.unigrams[lw2].prob1.l;
00680 }
00681
00682 b = FIRST_BG(model, lw1);
00683 n = FIRST_BG(model, lw1 + 1) - b;
00684 bg = model->lm3g.bigrams + b;
00685
00686 if ((i = find_bg(bg, n, lw2)) >= 0) {
00687
00688 *n_used = 2;
00689 score = model->lm3g.prob2[bg[i].prob2].l;
00690 }
00691 else {
00692
00693 *n_used = 1;
00694 score = model->lm3g.unigrams[lw1].bo_wt1.l + model->lm3g.unigrams[lw2].prob1.l;
00695 }
00696
00697 return (score);
00698 }
00699
00700 static void
00701 load_tginfo(ngram_model_arpa_t *model, int32 lw1, int32 lw2)
00702 {
00703 int32 i, n, b, t;
00704 bigram_t *bg;
00705 tginfo_t *tginfo;
00706
00707
00708 tginfo = (tginfo_t *) listelem_malloc(model->lm3g.le);
00709 tginfo->w1 = lw1;
00710 tginfo->tg = NULL;
00711 tginfo->next = model->lm3g.tginfo[lw2];
00712 model->lm3g.tginfo[lw2] = tginfo;
00713
00714
00715 b = model->lm3g.unigrams[lw1].bigrams;
00716 n = model->lm3g.unigrams[lw1 + 1].bigrams - b;
00717 bg = model->lm3g.bigrams + b;
00718
00719 if ((n > 0) && ((i = find_bg(bg, n, lw2)) >= 0)) {
00720 tginfo->bowt = model->lm3g.bo_wt2[bg[i].bo_wt2].l;
00721
00722
00723 b += i;
00724 t = FIRST_TG(model, b);
00725
00726 tginfo->tg = model->lm3g.trigrams + t;
00727
00728
00729 tginfo->n_tg = FIRST_TG(model, b + 1) - t;
00730 }
00731 else {
00732 tginfo->bowt = 0;
00733 tginfo->n_tg = 0;
00734 }
00735 }
00736
00737
00738 static int32
00739 find_tg(trigram_t * tg, int32 n, int32 w)
00740 {
00741 int32 i, b, e;
00742
00743 b = 0;
00744 e = n;
00745 while (e - b > BINARY_SEARCH_THRESH) {
00746 i = (b + e) >> 1;
00747 if ((int32)tg[i].wid < w)
00748 b = i + 1;
00749 else if ((int32)tg[i].wid > w)
00750 e = i;
00751 else
00752 return i;
00753 }
00754
00755 for (i = b; (i < e) && (tg[i].wid != w); i++);
00756 return ((i < e) ? i : -1);
00757 }
00758
00759 static int32
00760 lm3g_tg_score(ngram_model_arpa_t *model, int32 lw1,
00761 int32 lw2, int32 lw3, int32 *n_used)
00762 {
00763 ngram_model_t *base = &model->base;
00764 int32 i, n, score;
00765 trigram_t *tg;
00766 tginfo_t *tginfo, *prev_tginfo;
00767
00768 if ((base->n < 3) || (lw1 < 0))
00769 return (lm3g_bg_score(model, lw2, lw3, n_used));
00770
00771 prev_tginfo = NULL;
00772 for (tginfo = model->lm3g.tginfo[lw2]; tginfo; tginfo = tginfo->next) {
00773 if (tginfo->w1 == lw1)
00774 break;
00775 prev_tginfo = tginfo;
00776 }
00777
00778 if (!tginfo) {
00779 load_tginfo(model, lw1, lw2);
00780 tginfo = model->lm3g.tginfo[lw2];
00781 }
00782 else if (prev_tginfo) {
00783 prev_tginfo->next = tginfo->next;
00784 tginfo->next = model->lm3g.tginfo[lw2];
00785 model->lm3g.tginfo[lw2] = tginfo;
00786 }
00787
00788 tginfo->used = 1;
00789
00790
00791 n = tginfo->n_tg;
00792 tg = tginfo->tg;
00793 if ((i = find_tg(tg, n, lw3)) >= 0) {
00794
00795 *n_used = 3;
00796 score = model->lm3g.prob3[tg[i].prob3].l;
00797 }
00798 else {
00799 score = tginfo->bowt + lm3g_bg_score(model, lw2, lw3, n_used);
00800 }
00801
00802 return (score);
00803 }
00804
00805 static int32
00806 ngram_model_arpa_score(ngram_model_t *base, int32 wid,
00807 int32 *history, int32 n_hist,
00808 int32 *n_used)
00809 {
00810 ngram_model_arpa_t *model = (ngram_model_arpa_t *)base;
00811
00812 switch (n_hist) {
00813 case 0:
00814
00815 *n_used = 1;
00816 return model->lm3g.unigrams[wid].prob1.l;
00817 case 1:
00818 return lm3g_bg_score(model, history[0], wid, n_used);
00819 case 2:
00820 default:
00821
00822 return lm3g_tg_score(model, history[1], history[0], wid, n_used);
00823 }
00824 }
00825
00826 static int32
00827 ngram_model_arpa_raw_score(ngram_model_t *base, int32 wid,
00828 int32 *history, int32 n_hist,
00829 int32 *n_used)
00830 {
00831 ngram_model_arpa_t *model = (ngram_model_arpa_t *)base;
00832 int32 score;
00833
00834 switch (n_hist) {
00835 case 0:
00836
00837 *n_used = 1;
00838
00839 score = model->lm3g.unigrams[wid].prob1.l - base->log_wip;
00840
00841 score = (int32)(score / base->lw);
00842
00843 if (strcmp(base->word_str[wid], "<s>") != 0) {
00844 score = logmath_log(base->lmath,
00845 logmath_exp(base->lmath, score)
00846 - logmath_exp(base->lmath,
00847 base->log_uniform + base->log_uniform_weight));
00848 }
00849 return score;
00850 case 1:
00851 score = lm3g_bg_score(model, history[0], wid, n_used);
00852 break;
00853 case 2:
00854 default:
00855
00856 score = lm3g_tg_score(model, history[1], history[0], wid, n_used);
00857 break;
00858 }
00859
00860 return (int32)((score - base->log_wip) / base->lw);
00861 }
00862
00863 static int32
00864 ngram_model_arpa_add_ug(ngram_model_t *base,
00865 int32 wid, int32 lweight)
00866 {
00867 ngram_model_arpa_t *model = (ngram_model_arpa_t *)base;
00868 return lm3g_add_ug(base, &model->lm3g, wid, lweight);
00869 }
00870
00871 static void
00872 ngram_model_arpa_free(ngram_model_t *base)
00873 {
00874 ngram_model_arpa_t *model = (ngram_model_arpa_t *)base;
00875 ckd_free(model->lm3g.unigrams);
00876 ckd_free(model->lm3g.bigrams);
00877 ckd_free(model->lm3g.trigrams);
00878 ckd_free(model->lm3g.prob2);
00879 ckd_free(model->lm3g.bo_wt2);
00880 ckd_free(model->lm3g.prob3);
00881 lm3g_tginfo_free(base, &model->lm3g);
00882 ckd_free(model->lm3g.tseg_base);
00883 }
00884
00885 static void
00886 ngram_model_arpa_flush(ngram_model_t *base)
00887 {
00888 ngram_model_arpa_t *model = (ngram_model_arpa_t *)base;
00889 lm3g_tginfo_reset(base, &model->lm3g);
00890 }
00891
/*
 * Method table binding the generic ngram_model_t interface to the ARPA
 * implementation above; passed to ngram_model_init() by
 * ngram_model_arpa_read().  The initializer is positional, so the entry
 * order must match the member order of ngram_funcs_t (declared elsewhere).
 */
static ngram_funcs_t ngram_model_arpa_funcs = {
    ngram_model_arpa_free,
    ngram_model_arpa_apply_weights,
    ngram_model_arpa_score,
    ngram_model_arpa_raw_score,
    ngram_model_arpa_add_ug,
    ngram_model_arpa_flush
};