Files
minazuki-overlay/app-i18n/libkkc-data/files/sortlm.py
2026-04-10 09:47:38 -03:00

188 lines
6.6 KiB
Python

#!/usr/bin/python
# Copyright (C) 2011-2013 Daiki Ueno <ueno@gnu.org>
# Copyright (C) 2011-2013 Red Hat, Inc.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import struct
import marisa
import re
NGRAM = 3
NGRAM_LINE_REGEX = '^([-0-9.]+)[ \t]+([^\t]+?)(?:[ \t]+([-0-9.]+))?$'
class SortedGenerator(object):
def __init__(self, infile, output_prefix):
self.__infile = infile
self.__output_prefix = output_prefix
self.__ngram_line_regex = re.compile(NGRAM_LINE_REGEX)
self.__ngram_entries = [{} for x in range(0, NGRAM)]
self.__vocab_keyset = marisa.Keyset()
self.__input_keyset = marisa.Keyset()
self.__vocab_trie = marisa.Trie()
self.__input_trie = marisa.Trie()
self.__min_cost = 0.0
def read(self):
print("reading N-grams")
self.__read_tries()
self.__read_ngrams()
print("min cost = %lf" % self.__min_cost)
def __read_tries(self):
while True:
line = self.__infile.readline()
if line == "":
break
if line.startswith("\\1-grams"):
break
unigram_count = 0
while True:
line = self.__infile.readline()
if line == "":
break
line = line.strip()
if line == "":
break
match = self.__ngram_line_regex.match(line)
if not match:
continue
strv = match.groups()
self.__vocab_keyset.push_back(strv[1])
if not strv[1] in ("<s>", "</s>", "<UNK>"):
if "/" not in strv[1]:
continue
(input, output) = strv[1].split("/")
self.__input_keyset.push_back(input)
self.__vocab_trie.build(self.__vocab_keyset)
self.__input_trie.build(self.__input_keyset)
def __read_ngrams(self):
self.__infile.seek(0)
for n in range(1, NGRAM + 1):
while True:
line = self.__infile.readline()
if line == "":
break
if line.startswith("\\%s-grams:" % n):
break
while True:
line = self.__infile.readline()
if line == "":
break
line = line.strip()
if line == "":
break
match = self.__ngram_line_regex.match(line)
if not match:
continue
strv = match.groups()
ngram = strv[1].split(" ")
ids = []
for word in ngram:
agent = marisa.Agent()
agent.set_query(word)
if not self.__vocab_trie.lookup(agent):
continue
ids.append(agent.key_id())
cost = float(strv[0])
if cost != -99 and cost < self.__min_cost:
self.__min_cost = cost
backoff = 0.0
if strv[2]:
backoff = float(strv[2])
self.__ngram_entries[n - 1][tuple(ids)] = (cost, backoff)
def write(self):
self.__min_cost = -8.0
self.__write_tries()
self.__write_ngrams()
def __write_tries(self):
self.__vocab_trie.save(self.__output_prefix + ".1gram.index")
self.__input_trie.save(self.__output_prefix + ".input")
def __write_ngrams(self):
def quantize(cost, min_cost):
return max(0, min(65535, int(cost * 65535 / min_cost)))
def cmp_header(a, b):
return cmp(a[0], b[0])
print("writing 1-gram file")
unigram_offsets = {}
unigram_file = open("%s.1gram" % self.__output_prefix, "wb")
offset = 0
for ids, value in sorted(self.__ngram_entries[0].items()):
unigram_offsets[ids[0]] = offset
s = struct.pack("=HHH",
quantize(value[0], self.__min_cost),
quantize(value[1], self.__min_cost),
0 # reserved
)
unigram_file.write(s)
offset += 1
unigram_file.close()
print("writing 2-gram file")
bigram_offsets = {}
bigram_file = open("%s.2gram" % self.__output_prefix, "wb")
keys = list(self.__ngram_entries[1].keys())
items = [(struct.pack("=LL", ids[1], unigram_offsets[ids[0]]), ids) for ids in keys]
offset = 0
for header, ids in sorted(items, cmp=cmp_header):
value = self.__ngram_entries[1][ids]
bigram_offsets[ids] = offset
s = struct.pack("=HH",
quantize(value[0], self.__min_cost),
quantize(value[1], self.__min_cost))
bigram_file.write(header + s)
offset += 1
bigram_file.close()
if len(self.__ngram_entries[2]) > 0:
print("writing 3-gram file")
trigram_file = open("%s.3gram" % self.__output_prefix, "wb")
keys = list(self.__ngram_entries[2].keys())
items = [(struct.pack("=LL", ids[2], bigram_offsets[(ids[0], ids[1])]), ids) for ids in keys]
for header, ids in sorted(items, cmp=cmp_header):
value = self.__ngram_entries[2][ids]
s = struct.pack("=H",
quantize(value[0], self.__min_cost))
trigram_file.write(header + s)
trigram_file.close()
if __name__ == '__main__':
import sys
import argparse
parser = argparse.ArgumentParser(description='sortlm')
parser.add_argument('infile', nargs='?', type=argparse.FileType('r'),
default=sys.stdin,
help='language model file')
parser.add_argument('output_prefix', metavar='OUTPUT_PREFIX', type=str,
help='output file prefix')
args = parser.parse_args()
generator = SortedGenerator(args.infile, args.output_prefix)
generator.read();
generator.write();