機械学習・自然言語処理の勉強メモ

学んだことのメモやまとめ

Snorkelの前処理について(実装編)②

はじめに



前回はcandidateとして文章から人物名の組み合わせを抽出するチュートリアルを見た。
kento1109.hatenablog.com

前回の固有表現抽出は対象が人物であったので、「Stanford CoreNLP」や「Spcay」を用いて容易に抽出できた。
今回はIREX標準の固有表現でないEntityを対象とする。
今回のタスクは「Chemical-Disease Relations」(薬品と病名の関係抽出)。
目標は以下の文章から

Warfarin-induced artery calcification is accelerated by growth and vitamin D.

薬品と病名の組み合わせ「("warfarin", "artery calcification")」を抽出すること。

今回は薬品名や病名といった固有表現をどのように抽出するのかを理解する。

コードを読んでいく



今回のチュートリアルは以下にある。
github.com

DocPreprocessor

今回はXMLファイルを読み込む。
XMLファイルの以下のような定義

<document>
<id>227508</id>
<passage>
<infon key="type">title</infon>
<offset>0</offset>
<text>Naloxone reverses the antihypertensive effect of clonidine.</text>
<annotation id='0'>
<infon key="type">Chemical</infon>
<infon key="MESH">D009270</infon>
<location offset='0' length='8' />
<text>Naloxone</text>
</annotation>
<annotation id='1'>
<infon key="type">Chemical</infon>
<infon key="MESH">D003000</infon>
<location offset='49' length='9' />
<text>clonidine</text>
</annotation>
</passage>
CorpusParser

パース処理に関するクラス。
前回は以下のように定義していた。(引数にSpacy()を指定していた。)

from snorkel.parser.spacy_parser import Spacy
from snorkel.parser import CorpusParser

corpus_parser = CorpusParser(parser=Spacy())
corpus_parser.apply(doc_preprocessor, count=n_docs)

今回はSpacyのNERは使えないので、以下のように定義している。

from snorkel.parser import CorpusParser
from utils import TaggerOneTagger

tagger_one = TaggerOneTagger()
TaggerOneTagger

これが今回のパースの鍵を握るクラスなのだろう。
説明を読んでみると、PubMedのNERで用いられる固有表現抽出器のようだ。

class CDRTagger(object):

    def __init__(self, fname='data/unary_tags.pkl.bz2'):   
        with bz2.BZ2File(fname, 'rb') as f:
            self.tag_dict = load(f)

class TaggerOneTagger(CDRTagger):
    
    def __init__(self, fname_tags='data/taggerone_unary_tags_cdr.pkl.bz2',
        fname_mesh='data/chem_dis_mesh_dicts.pkl.bz2'):
        with bz2.BZ2File(fname_tags, 'rb') as f:
            self.tag_dict = load(f)
        with bz2.BZ2File(fname_mesh, 'rb') as f:
            self.chem_mesh_dict, self.dis_mesh_dict = load(f)

下記のファイルが辞書の役割を果たす。

  • taggerone_unary_tags_cdr.pkl.bz2
  • chem_dis_mesh_dicts.pkl.bz2

内容を少し見てみる。

taggerone_unary_tags_cdr.pkl.bz2
import bz2
from six.moves.cPickle import load

with bz2.BZ2File('data/taggerone_unary_tags_cdr.pkl.bz2', 'rb') as f:
    tag_dict = load(f)
for i, (key, value) in enumerate(tag_dict.iteritems()):
    print "PMID:", key
    for j, val in enumerate(value):
        print val
        if j > 3:
            break
    if i > 3:
        break

PMID: 11672959
('Disease|D008206', 117, 132)
('Disease|D009102', 668, 687)
('Disease|D056486', 137, 146)
('Disease|D001172', 227, 247)
('Disease|D056486', 519, 541)
PMID: 9636837
('Chemical|D002945', 1097, 1101)
('Chemical|D016190', 472, 477)
('Disease|D020258', 191, 201)
('Chemical|D016190', 0, 11)
('Chemical|D010984', 1210, 1218)
PMID: 1445986
('Chemical|D015313', 273, 282)
('Chemical|D015313', 284, 293)
('Disease|D000743', 50, 66)
('Disease|D000743', 25, 41)
('Chemical|D002511', 186, 200)
PMID: 19274460
('Disease|D006402|D005767', 1036, 1081)
('Chemical|D003907', 294, 307)
('Chemical|C400082', 84, 94)
('Disease|D009101', 424, 426)
('Disease|D009101', 1281, 1283)
PMID: 35781
('Chemical|D009638', 770, 772)
('Disease|D002375', 605, 614)
('Chemical|C009695', 727, 741)
('Disease|D002375', 244, 254)
('Chemical|D009278', 137, 148)
chem_dis_mesh_dicts.pkl.bz2

薬品名と病名のMESHID対応辞書

with bz2.BZ2File('data/chem_dis_mesh_dicts.pkl.bz2', 'rb') as f:
    chem_mesh_dict, dis_mesh_dict = load(f)

for i, (key, value) in enumerate(chem_mesh_dict.iteritems()):
    print "MESHID:", key
    print "value:", value
    if i > 3:
        break

MESHID: 'catechols'
value: 'D002396'
MESHID: 'transferrin-binding protein complex, bacterial'
value: 'D033901'
MESHID: 'glp 1 receptor'
value: 'D000067757'
MESHID: 'carbamoyl-phosphate synthase (glutamine)'
value: 'D002223'
MESHID: 'c fes proto oncogene proteins'
value: 'D051578'

for i, (key, value) in enumerate(dis_mesh_dict.iteritems()):
    print "MESHID:", key
    print "value:", value
    if i > 3:
        break

MESHID: 'leukemia, plasmacytic'
value: 'D007952'
MESHID: 'fusion of kidney'
value: 'D000069337'
MESHID: 'epilepsies, anterior fronto-polar'
value: 'D017034'
MESHID: 'cholera infantum'
value: 'D005767'
MESHID: 'hunermann-conradi syndrome'
value: 'D002806'

これらの関係を確認する。
例えば、

PMID: 1445986
('Chemical|D015313', 273, 282)
('Chemical|D015313', 284, 293)
('Disease|D000743', 50, 66)
('Disease|D000743', 25, 41)
('Chemical|D002511', 186, 200)

の病名と薬品名を調べる。
例えば、D015313の薬品名は、

chem_mesh_dict_inv = {v:k for k, v in chem_mesh_dict.items()}
chem_mesh_dict_inv.get('D015313')
>> 'ym09330'

であり、D000743の病名は、

dis_mesh_dict_inv = {v:k for k, v in dis_mesh_dict.items()}
dis_mesh_dict_inv.get('D000743')
>> 'anemia, microangiopathic'

だということが分かる。
※PMID: 1445986の論文名は「Cefotetan-induced immune hemolytic anemia.

脱線したが、このような辞書をTaggerOneTaggerのインスタンスメソッドで定義する。
次にtagメソッドについて見ていく。
*呼び出し側

corpus_parser = CorpusParser(fn=tagger_one.tag)
corpus_parser.apply(list(doc_preprocessor))

※1行目ではfnの引数として定義しただけで実行はしていない

class CDRTagger(object):

    def tag(self, parts):
        pubmed_id, _, _, sent_start, sent_end = parts['stable_id'].split(':')
        sent_start, sent_end = int(sent_start), int(sent_end)
        tags = self.tag_dict.get(pubmed_id, {})
        for tag in tags:
            if not (sent_start <= tag[1] <= sent_end):
                continue
            offsets = [offset + sent_start for offset in parts['char_offsets']]
            toks = offsets_to_token(tag[1], tag[2], offsets, parts['lemmas'])
            for tok in toks:
                ts = tag[0].split('|')
                parts['entity_types'][tok] = ts[0]
                parts['entity_cids'][tok] = ts[1]
        return parts

class TaggerOneTagger(CDRTagger):
    
    def tag(self, parts):
        parts = super(TaggerOneTagger, self).tag(parts)
        for i, word in enumerate(parts['words']):
            tag = parts['entity_types'][i]
            if len(word) > 4 and tag is None:
                wl = word.lower()
                if wl in self.dis_mesh_dict:
                    parts['entity_types'][i] = 'Disease'
                    parts['entity_cids'][i] = self.dis_mesh_dict[wl]
                elif wl in self.chem_mesh_dict:
                    parts['entity_types'][i] = 'Chemical'
                    parts['entity_cids'][i] = self.chem_mesh_dict[wl]
        return parts

2行目のapply関数でパース処理を実行する。
今回はparser引数を指定していないので、self.parserにはStanfordCoreNLPServerが格納される。
(ここで、parser引数を指定すると対応するクラスでパース処理が行われる。)

class CorpusParser(UDFRunner):

    def __init__(self, parser=None, fn=None):
        self.parser = parser or StanfordCoreNLPServer()
        super(CorpusParser, self).__init__(CorpusParserUDF,
                                           parser=self.parser,
                                           fn=fn)

class CorpusParserUDF(UDF):

    def __init__(self, parser, fn, **kwargs):
        super(CorpusParserUDF, self).__init__(**kwargs)
        self.parser = parser
        self.req_handler = parser.connect()
        self.fn = fn

    def apply(self, x, **kwargs):
        """Given a Document object and its raw text, parse into Sentences"""
        doc, text = x
        for parts in self.req_handler.parse(doc, text):
            parts = self.fn(parts) if self.fn is not None else parts
            yield Sentence(**parts)

self.req_handler.parseの実体は、StanfordCoreNLPServer.parse

class StanfordCoreNLPServer(Parser):

    def parse(self, document, text, conn):
        '''
        Parse CoreNLP JSON results. Requires an external connection/request object to remain threadsafe
        :param document:
        :param text:
        :param conn: server connection
        :return:
        '''
        if len(text.strip()) == 0:
            sys.stderr.write("Warning, empty document {0} passed to CoreNLP".format(document.name if document else "?"))
            return

        # handle encoding (force to unicode)
        if isinstance(text, str):
            text = text.encode('utf-8', 'error')

        # POST request to CoreNLP Server
        try:
            content = conn.post(self.endpoint, text)
            content = content.decode(self.encoding)

        except socket.error as e:
            sys.stderr.write("Socket error")
            raise ValueError("Socket Error")

        # check for parsing error messages
        StanfordCoreNLPServer.validate_response(content)

        try:
            blocks = json.loads(content, strict=False)['sentences']
        except:
            warnings.warn("CoreNLP skipped a malformed document.", RuntimeWarning)

        position = 0
        for block in blocks:
            parts = defaultdict(list)
            dep_order, dep_par, dep_lab = [], [], []
            for tok, deps in zip(block['tokens'], block[StanfordCoreNLPServer.BLOCK_DEFS[self.version]]):
                # Convert PennTreeBank symbols back into characters for words/lemmas
                parts['words'].append(StanfordCoreNLPServer.PTB.get(tok['word'], tok['word']))
                parts['lemmas'].append(StanfordCoreNLPServer.PTB.get(tok['lemma'], tok['lemma']))
                parts['pos_tags'].append(tok['pos'])
                parts['ner_tags'].append(tok['ner'])
                parts['char_offsets'].append(tok['characterOffsetBegin'])
                dep_par.append(deps['governor'])
                dep_lab.append(deps['dep'])
                dep_order.append(deps['dependent'])

            # certain configuration options remove 'before'/'after' fields in output JSON (TODO: WHY?)
            # In order to create the 'text' field with correct character offsets we use
            # 'characterOffsetEnd' and 'characterOffsetBegin' to build our string from token input
            text = ""
            for t in block['tokens']:
                # shift to start of local sentence offset
                i = t['characterOffsetBegin'] - block['tokens'][0]['characterOffsetBegin']
                # add whitespace based on offsets of originalText
                text += (' ' * (i - len(text))) + t['originalText'] if len(text) != i else t['originalText']
            parts['text'] = text

            # make char_offsets relative to start of sentence
            abs_sent_offset = parts['char_offsets'][0]
            parts['char_offsets'] = [p - abs_sent_offset for p in parts['char_offsets']]
            parts['abs_char_offsets'] = [p for p in parts['char_offsets']]
            parts['dep_parents'] = sort_X_on_Y(dep_par, dep_order)
            parts['dep_labels'] = sort_X_on_Y(dep_lab, dep_order)
            parts['position'] = position

            # Add full dependency tree parse to document meta
            if 'parse' in block and document:
                tree = ' '.join(block['parse'].split())
                if 'tree' not in document.meta:
                    document.meta['tree'] = {}
                document.meta['tree'][position] = tree

            # Link the sentence to its parent document object
            parts['document'] = document if document else None

            # Add null entity array (matching null for CoreNLP)
            parts['entity_cids'] = ['O' for _ in parts['words']]
            parts['entity_types'] = ['O' for _ in parts['words']]

            # Assign the stable id as document's stable id plus absolute character offset
            abs_sent_offset_end = abs_sent_offset + parts['char_offsets'][-1] + len(parts['words'][-1])

            if document:
                parts['stable_id'] = construct_stable_id(document, 'sentence', abs_sent_offset, abs_sent_offset_end)
            position += 1
            yield parts

その後、sentence毎にtagger_one.tag(parts)を呼び出す。ここで、parts['words']がインスタンスメソッドで定義したPubMed辞書に存在するか確認する。
存在した場合、

  • parts['entity_types'][i] = 'Disease|Chemical'
  • parts['entity_cids'][i] = self.dis_mesh_dict[wl]| self.chem_mesh_dict[wl]

を格納する。
utils.py

def offsets_to_token(left, right, offset_array, lemmas, punc=set(punctuation)):
    token_start, token_end = None, None
    for i, c in enumerate(offset_array):
        if left >= c:
            token_start = i
        if c > right and token_end is None:
            token_end = i
            break
    token_end = len(offset_array) - 1 if token_end is None else token_end
    token_end = token_end - 1 if lemmas[token_end - 1] in punc else token_end
    return range(token_start, token_end)


class CDRTagger(object):

    def __init__(self, fname='data/unary_tags.pkl.bz2'):   
        with bz2.BZ2File(fname, 'rb') as f:
            self.tag_dict = load(f)

    def tag(self, parts):
        pubmed_id, _, _, sent_start, sent_end = parts['stable_id'].split(':')
        sent_start, sent_end = int(sent_start), int(sent_end)
        tags = self.tag_dict.get(pubmed_id, {})
        for tag in tags:
            if not (sent_start <= tag[1] <= sent_end):
                continue
            offsets = [offset + sent_start for offset in parts['char_offsets']]
            toks = offsets_to_token(tag[1], tag[2], offsets, parts['lemmas'])
            for tok in toks:
                ts = tag[0].split('|')
                parts['entity_types'][tok] = ts[0]
                parts['entity_cids'][tok] = ts[1]
        return parts


class TaggerOneTagger(CDRTagger):
    
    def __init__(self, fname_tags='data/taggerone_unary_tags_cdr.pkl.bz2',
        fname_mesh='data/chem_dis_mesh_dicts.pkl.bz2'):
        with bz2.BZ2File(fname_tags, 'rb') as f:
            self.tag_dict = load(f)
        with bz2.BZ2File(fname_mesh, 'rb') as f:
            self.chem_mesh_dict, self.dis_mesh_dict = load(f)

    def tag(self, parts):
        parts = super(TaggerOneTagger, self).tag(parts)
        for i, word in enumerate(parts['words']):
            tag = parts['entity_types'][i]
            if len(word) > 4 and tag is None:
                wl = word.lower()
                if wl in self.dis_mesh_dict:
                    parts['entity_types'][i] = 'Disease'
                    parts['entity_cids'][i] = self.dis_mesh_dict[wl]
                elif wl in self.chem_mesh_dict:
                    parts['entity_types'][i] = 'Chemical'
                    parts['entity_cids'][i] = self.chem_mesh_dict[wl]
        return parts
PretaggedCandidateExtractor

Candidateの抽出までは前回とほとんど同じだが、抽出部の処理は少し異なる。
前回の抽出はこう書いた。(NER='PERSON'を抽出する専用のクラスを使用した。)

from snorkel.candidates import Ngrams, CandidateExtractor
from snorkel.matchers import PersonMatcher

ngrams         = Ngrams(n_max=7)
person_matcher = PersonMatcher(longest_match_only=True)
cand_extractor = CandidateExtractor(Spouse, 
                                    [ngrams, ngrams], [person_matcher, person_matcher],
                                    symmetric_relations=False)

今回はこのままは使えないので下記のように書く。

from snorkel.candidates import PretaggedCandidateExtractor

candidate_extractor = PretaggedCandidateExtractor(ChemicalDisease, ['Chemical', 'Disease'])

名前の通りPretaggedCandidateExtractorクラスが事前に定義したタグ付き単語を抽出してくれる。
抽出部はとても長いので要点のみまとめる。

  • 指定したエンティティの辞書を用意
# Do a first pass to collect all mentions by entity type / cid
entity_idxs = dict((et, defaultdict(list)) for et in set(self.entity_types))
L = len(context.words)

>> entity_idxs = {'chemical':{[]}, 'disease':{[]}}
  • 単語のentity_typesが空かどうかチェック
  • 空でない場合、entity_typesがentity_idxsに含まれるかチェック
  • 含まれる場合、entity_idxsに追加
if context.entity_types[i] is not None:
    ets  = context.entity_types[i].split(self.entity_sep)
    cids = context.entity_cids[i].split(self.entity_sep)
    for et, cid in zip(ets, cids):
        if et in entity_idxs:
            entity_idxs[et][cid].append(i)

>> entity_idxs = {'chemical':{D002396:[catechols]}, 'disease':{D000069337:[fusion, of, kidney]}}

これで指定した固有表現の候補を抽出することが出来る。

最後に



日本語の固有表現抽出の実装を考える場合、今回のようにSnorkel内で行うのか、既にタグ付け済のコーパスを使うのかによって大きくな異なる。
Snorkel内で行う場合、今回と同様のやり方でほとんど対応可能だと思う。しかし、固有表現抽出は今回のMESHのような外部リソースを活用する方法に限定される。
別の方法で固有表現抽出を行いたい場合、事前にタグ付けが行い、それをSnorkelの入力とするのが良いと思う。
いずれの場合もparseで「Stanford CoreNLP」や「Spcay」が利用できないので、その部分の実装はどうしても必要になる。