From 615e6ea962c3b4262f75f8ace5a5952303315f8a Mon Sep 17 00:00:00 2001 From: moreymat Date: Fri, 27 Jan 2017 11:38:07 +0100 Subject: [PATCH 01/44] MAINT educe.rst_dt.similarity flake8 --- educe/rst_dt/similarity/rst_study.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/educe/rst_dt/similarity/rst_study.py b/educe/rst_dt/similarity/rst_study.py index 52d12a1..5f8e8a8 100644 --- a/educe/rst_dt/similarity/rst_study.py +++ b/educe/rst_dt/similarity/rst_study.py @@ -231,14 +231,14 @@ def wmd(i, j): for gov_idx, dep_idx in itertools.combinations( range(len(dtree.edus)), 2): if (doc_key, gov_idx, dep_idx) in edu_pairs_rel: - lbl = edu_pairs_rel[(doc_key, gov_idx, dep_idx)] - kept_pair = (doc_key, gov_idx, dep_idx, lbl) + kept_pair = (doc_key, gov_idx, dep_idx, + edu_pairs_rel[(doc_key, gov_idx, dep_idx)]) elif (doc_key, dep_idx, gov_idx) in edu_pairs_rel: - lbl = edu_pairs_rel[(doc_key, dep_idx, gov_idx)] - kept_pair = (doc_key, dep_idx, gov_idx, lbl) + kept_pair = (doc_key, dep_idx, gov_idx, + edu_pairs_rel[(doc_key, dep_idx, gov_idx)]) else: - lbl = 'UNRELATED' - kept_pair = (doc_key, gov_idx, dep_idx, lbl) + kept_pair = (doc_key, gov_idx, dep_idx, + 'UNRELATED') edu_pairs[-1].append(kept_pair) # transform local index of EDU in doc into global index in the list From fee0a882e8b4d6e9c0f9fb740e7b26e892f15c3d Mon Sep 17 00:00:00 2001 From: moreymat Date: Fri, 27 Jan 2017 11:46:59 +0100 Subject: [PATCH 02/44] ENH move, refactor structured metrics (inc. Parseval) from attelo --- educe/metrics/parseval.py | 316 +++++++++++++++++++++++++++ educe/metrics/scores_structured.py | 166 ++++++++++++++ educe/rst_dt/metrics/__init__.py | 0 educe/rst_dt/metrics/rst_parseval.py | 244 +++++++++++++++++++++ 4 files changed, 726 insertions(+) create mode 100644 educe/metrics/parseval.py create mode 100644 educe/metrics/scores_structured.py create mode 100644 educe/rst_dt/metrics/__init__.py create mode 100644 educe/rst_dt/metrics/rst_parseval.py diff --git a/educe/metrics/parseval.py b/educe/metrics/parseval.py new file mode 100644 index 0000000..74057b9 --- /dev/null +++ b/educe/metrics/parseval.py @@ -0,0 +1,316 @@ +"""Parseval metrics for constituency trees. + +TODO +---- +* [ ] factor out the report from the parseval scoring function, see +`sklearn.metrics.classification.classification_report` +* [ ] refactor the selection functions that enable to break down +evaluations, to avoid almost duplicates (as currently) +""" + +from __future__ import absolute_import, print_function + +import numpy as np + +from educe.metrics.scores_structured import (precision_recall_fscore_support, + unique_labels) + + +def parseval_scores(ctree_true, ctree_pred, subtree_filter=None, + exclude_root=False, lbl_fn=None, labels=None, + average=None, per_doc=False, + add_trivial_spans=False): + """Compute PARSEVAL scores for ctree_pred wrt ctree_true. + + Parameters + ---------- + ctree_true : list of list of RSTTree or SimpleRstTree + List of reference RST trees, one per document. + + ctree_pred : list of list of RSTTree or SimpleRstTree + List of predicted RST trees, one per document. + + subtree_filter : function, optional + Function to filter all local trees. + + exclude_root : boolean, defaults to True + If True, exclude the root node of both ctrees from the eval. + + lbl_fn: function, optional + Function to relabel spans. 
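+        For instance, ``lambda span: span[2]`` keeps only the relation
+        label of each span description (cf. the label extraction
+        functions in ``educe.rst_dt.metrics.rst_parseval``).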
+ + labels : list of string, optional + Corresponds to sklearn's target_names IMO + + average : one of {'micro', 'macro'}, optional + TODO, see scores_structured + + per_doc : boolean, optional + If True, precision, recall and f1 are computed for each document + separately then averaged over documents. + (TODO this should probably be pushed down to + `scores_structured.precision_recall_fscore_support`) + + Returns + ------- + precision : float (if average is not None) or array of float, shape =\ + [n_unique_labels] + Weighted average of the precision of each class. + + recall : float (if average is not None) or array of float, shape =\ + [n_unique_labels] + + fbeta_score : float (if average is not None) or array of float, shape =\ + [n_unique_labels] + + support_true : int (if average is not None) or array of int, shape =\ + [n_unique_labels] + The number of occurrences of each label in ``ctree_true``. + + support_pred : int (if average is not None) or array of int, shape =\ + [n_unique_labels] + The number of occurrences of each label in ``ctree_pred``. + + """ + # WIP + if add_trivial_spans: + # force inclusion of root span 1-n + exclude_root = False + + # extract descriptions of spans from the true and pred trees + spans_true = [ct.get_spans(subtree_filter=subtree_filter, + exclude_root=exclude_root) + for ct in ctree_true] + spans_pred = [ct.get_spans(subtree_filter=subtree_filter, + exclude_root=exclude_root) + for ct in ctree_pred] + + # WIP replicate eval in Li et al.'s dep parser + if add_trivial_spans: + # add trivial spans for 0-0 and 0-n + # this assumes n-n is the last span so we can get "n" as + # sp_list[-1][0][1] + spans_true = [sp_list + [((0, 0), "Root", '---', 0), + ((0, sp_list[-1][0][1]), "Root", '---', 0)] + for sp_list in spans_true] + spans_pred = [sp_list + [((0, 0), "Root", '---', 0), + ((0, sp_list[-1][0][1]), "Root", '---', 0)] + for sp_list in spans_pred] + # if label != span, change nuclearity to Satellite + spans_true = [[(x[0], "Satellite" if x[2].lower() != "span" else x[1], + x[2], x[3]) for x in sp_list] + for sp_list in spans_true] + spans_pred = [[(x[0], "Satellite" if x[2].lower() != "span" else x[1], + x[2], x[3]) for x in sp_list] + for sp_list in spans_pred] + # end WIP + # use lbl_fn to define labels + if lbl_fn is not None: + spans_true = [[(span[0], lbl_fn(span)) for span in spans] + for spans in spans_true] + spans_pred = [[(span[0], lbl_fn(span)) for span in spans] + for spans in spans_pred] + + # NEW gather present labels + present_labels = unique_labels(spans_true, spans_pred) + if labels is None: + labels = present_labels + else: + # currently not tested + labels = np.hstack([labels, np.setdiff1d(present_labels, labels, + assume_unique=True)]) + # end NEW labels + + if per_doc: + # non-standard variant that computes scores per doc then + # averages them over docs ; this variant is implemented in DPLP + # where it is mistaken for the standard version + scores = [] + for doc_spans_true, doc_spans_pred in zip(spans_true, spans_pred): + p, r, f1, s_true, s_pred = precision_recall_fscore_support( + [doc_spans_true], [doc_spans_pred], labels=labels, + average=average) + scores.append((p, r, f1, s_true, s_pred)) + p, r, f1, s_true, s_pred = ( + np.array([x[0] for x in scores]).mean(), + np.array([x[1] for x in scores]).mean(), + np.array([x[2] for x in scores]).mean(), + np.array([x[3] for x in scores]).sum(), + np.array([x[4] for x in scores]).sum() + ) + else: + # standard version of this eval + p, r, f1, s_true, s_pred = 
precision_recall_fscore_support( + spans_true, spans_pred, labels=labels, average=average) + + return p, r, f1, s_true, s_pred, labels + + +def parseval_report(ctree_true, ctree_pred, exclude_root=False, + subtree_filter=None, lbl_fns=None, digits=4, + print_support_pred=True, per_doc=False, + add_trivial_spans=False): + """Build a text report showing the PARSEVAL discourse metrics. + + This is the simplest report we need to generate, it corresponds + to the arrays of results from the literature. + Metrics are calculated globally (average='micro'). + + Parameters + ---------- + ctree_true: TODO + TODO + ctree_pred: TODO + TODO + metric_types: list of strings, optional + Metrics that need to be included in the report ; if None is + given, defaults to ['S', 'S+N', 'S+R', 'S+N+R']. + digits: int, defaults to 4 + Number of decimals to print. + print_support_pred: boolean, defaults to True + If True, the predicted support, i.e. the number of predicted + spans, is also displayed. This is useful for non-binary ctrees + as the number of spans in _true and _pred can differ. + span_sel: TODO + TODO + per_doc: boolean, defaults to False + If True, compute p, r, f for each doc separately then compute the + mean of each score over docs. This is *not* the correct + implementation, but it corresponds to that in DPLP. + """ + if lbl_fns is None: + # we require a labelled span to be a pair (span, lbl) + # where span and lbl can be anything, for example + # * span = (span_beg, span_end) + # * lbl = (nuc, rel) + lbl_fns = [('Labelled Span', lambda span_lbl: span_lbl[1])] + + metric_types = [k for k, v in lbl_fns] + + # prepare scaffold for report + width = max(len(str(x)) for x in metric_types) + width = max(width, digits) + headers = ["precision", "recall", "f1-score", "support", "sup_pred"] + fmt = '%% %ds' % width # first col: class name + fmt += ' ' + fmt += ' '.join(['% 9s' for _ in headers]) + fmt += '\n' + headers = [""] + headers + report = fmt % tuple(headers) + report += '\n' + + # compute scores + metric_scores = dict() + for metric_type, lbl_fn in lbl_fns: + p, r, f1, s_true, s_pred, labels = parseval_scores( + ctree_true, ctree_pred, subtree_filter=subtree_filter, + exclude_root=exclude_root, lbl_fn=lbl_fn, labels=None, + average='micro', per_doc=per_doc, + add_trivial_spans=add_trivial_spans) + metric_scores[metric_type] = (p, r, f1, s_true, s_pred) + + # fill report + for metric_type in metric_types: + (p, r, f1, s_true, s_pred) = metric_scores[metric_type] + values = [metric_type] + for v in (p, r, f1): + values += ["{0:0.{1}f}".format(v, digits)] + values += ["{0}".format(s_true)] # support_true + values += ["{0}".format(s_pred)] # support_pred + report += fmt % tuple(values) + + return report + + +def parseval_detailed_report(ctree_true, ctree_pred, exclude_root=False, + subtree_filter=None, lbl_fn=None, + labels=None, sort_by_support=True, + digits=4, per_doc=False): + """Build a text report showing the PARSEVAL discourse metrics. + + FIXME model after sklearn.metrics.classification.classification_report + + Parameters + ---------- + ctree_true : list of RSTTree or SimpleRstTree + Ground truth (correct) target structures. + + ctree_pred : list of RSTTree or SimpleRstTree + Estimated target structures as predicted by a parser. + + labels : list of string, optional + Relation labels to include in the evaluation. + FIXME Corresponds more to target_names in sklearn IMHO. 
+ + lbl_fn : function from tuple((int, int), (string, string)) to string + Label extraction function + + digits : int + Number of digits for formatting output floating point values. + + Returns + ------- + report : string + Text summary of the precision, recall, F1 score, support for each + class (or micro-averaged over all classes). + + """ + if lbl_fn is None: + # we require a labelled span to be a pair (span, lbl) + # where span and lbl can be anything, for example + # * span = (span_beg, span_end) + # * lbl = (nuc, rel) + lbl_fn = ('Labelled Span', lambda span_lbl: span_lbl[1]) + # FIXME param lbl_fn is in fact a pair (metric_type, lbl_fn) + metric_type, lbl_fn = lbl_fn + + # call with average=None to compute per-class scores, then + # compute average here and print it + p, r, f1, s_true, s_pred, labels = parseval_scores( + ctree_true, ctree_pred, subtree_filter=subtree_filter, + exclude_root=exclude_root, lbl_fn=lbl_fn, labels=labels, + average=None, per_doc=per_doc) + + # scaffold for report + last_line_heading = 'avg / total' + + width = max(len(str(lbl)) for lbl in labels) + width = max(width, len(last_line_heading), digits) + + headers = ["precision", "recall", "f1-score", "support", "sup_pred"] + fmt = '%% %ds' % width # first col: class name + fmt += ' ' + fmt += ' '.join(['% 9s' for _ in headers]) + fmt += '\n' + + headers = [""] + headers + report = fmt % tuple(headers) + report += '\n' + + # map labels to indices, possibly sorted by their support + sorted_ilbls = enumerate(labels) + if sort_by_support: + sorted_ilbls = sorted(sorted_ilbls, key=lambda x: s_true[x[0]], + reverse=True) + # one line per label + for i, label in sorted_ilbls: + values = [label] + for v in (p[i], r[i], f1[i]): + values += ["{0:0.{1}f}".format(v, digits)] + values += ["{0}".format(s_true[i])] + values += ["{0}".format(s_pred[i])] + report += fmt % tuple(values) + + report += '\n' + + # last line ; compute averages + values = [last_line_heading] + for v in (np.average(p, weights=s_true), + np.average(r, weights=s_true), + np.average(f1, weights=s_true)): + values += ["{0:0.{1}f}".format(v, digits)] + values += ['{0}'.format(np.sum(s_true))] + values += ['{0}'.format(np.sum(s_pred))] + report += fmt % tuple(values) + + return report diff --git a/educe/metrics/scores_structured.py b/educe/metrics/scores_structured.py new file mode 100644 index 0000000..ba07ce4 --- /dev/null +++ b/educe/metrics/scores_structured.py @@ -0,0 +1,166 @@ +"""Classification metrics for structured outputs. + +""" + +from collections import Counter +from itertools import chain, izip + +import numpy as np + + +def _unique_labels(y): + """Set of unique labels in y""" + return set(y_ij[1] for y_ij in + chain.from_iterable(y_i for y_i in y)) + + +def unique_labels(*ys): + """Extract an ordered array of unique labels. + + Parameters + ---------- + elt_type: string + Type of each element, determines how to find the label + + See also + -------- + This is the structured version of + `sklearn.utils.multiclass.unique_labels` + """ + ys_labels = set(chain.from_iterable(_unique_labels(y) for y in ys)) + # TODO check the set of labels contains a unique (e.g. string) type + # of values + return np.array(sorted(ys_labels)) + + +def precision_recall_fscore_support(y_true, y_pred, labels=None, + average=None, return_support_pred=True): + """Compute precision, recall, F-measure and support for each class. + + The support is the number of occurrences of each class in + ``y_true``. 
+ + This is essentially a structured version of + sklearn.metrics.classification.precision_recall_fscore_support . + + It should apply equally well to lists of constituency tree spans + and lists of dependency edges. + + Parameters + ---------- + y_true: list of iterable + Ground truth target structures, encoded in a sparse format (e.g. + list of edges or span descriptions). + + y_pred: list of iterable + Estimated target structures, encoded in a sparse format (e.g. list + of edges or span descriptions). + + labels: list, optional + The set of labels to include, and their order if ``average is + None``. + + average: string, [None (default), 'binary', 'micro', 'macro'] + If ``None``, the scores for each class are returned. Otherwise, + this determines the type of averaging performed on the data: + + ``'binary'``: + Only report results for the positive class. + This is applicable only if targets are binary. + ``'micro'``: + Calculate metrics globally by counting the total true + positives, false negatives and false positives. + ``'macro'``: + Calculate metrics for each label, and find their unweighted + mean. This does not take label imbalance into account. + + return_support_pred: boolean, True by default + If True, output the support of the prediction. This is useful + for structured prediction because y_true and y_pred can differ + in length. + + Returns + ------- + precision: float (if average is not None) or array of float, shape=\ + [n_unique_labels] + + recall: float (if average is not None) or array of float, shape=\ + [n_unique_labels] + + fscore: float (if average is not None) or array of float, shape=\ + [n_unique_labels] + + support: int (if average is not None) or array of int, shape=\ + [n_unique_labels] + The number of occurrences of each label in ``ctree_true``. + + support_pred: int (if average is not None) or array of int, shape=\ + [n_unique_labels], if ``return_support_pred``. + If The number of occurrences of each label in ``ctree_pred``. + """ + average_options = frozenset([None, 'micro', 'macro']) + if average not in average_options: + raise ValueError('average has to be one of' + + str(average_options)) + # TMP + if average == 'macro': + raise NotImplementedError('average currently has to be micro or None') + # end TMP + + # gather an ordered list of unique labels from y_true and y_pred + present_labels = unique_labels(y_true, y_pred) + + if labels is None: + labels = present_labels + # n_labels = None + else: + # EXPERIMENTAL + labels = [lbl for lbl in labels if lbl in present_labels] + # n_labels = len(labels) + # FIXME complete/fix this + # raise ValueError('Parameter `labels` is currently unsupported') + # end EXPERIMENTAL + + # compute tp_sum, pred_sum, true_sum + # true positives for each tree + tp = [set(yi_true) & set(yi_pred) + for yi_true, yi_pred in izip(y_true, y_pred)] + + # TODO find a nicer and faster design that resembles sklearn's, e.g. 
+ # use np.bincount instead of collections.Counter + tp_sum = Counter(y_ij[1] for y_ij in chain.from_iterable(tp)) + true_sum = Counter(y_ij[1] for y_ij in chain.from_iterable(y_true)) + pred_sum = Counter(y_ij[1] for y_ij in chain.from_iterable(y_pred)) + # transform to np arrays of floats + tp_sum = np.array([float(tp_sum[lbl]) for lbl in labels]) + true_sum = np.array([float(true_sum[lbl]) for lbl in labels]) + pred_sum = np.array([float(pred_sum[lbl]) for lbl in labels]) + + # TODO rewrite to compute by summing over scores broken down by label + if average == 'micro': + tp_sum = np.array([tp_sum.sum()]) + true_sum = np.array([true_sum.sum()]) + pred_sum = np.array([pred_sum.sum()]) + + # finally compute the desired statistics + # when the div denominator is 0, assign 0.0 (instead of np.inf) + precision = tp_sum / pred_sum + precision[pred_sum == 0] = 0.0 + + recall = tp_sum / true_sum + recall[true_sum == 0] = 0.0 + + f_score = 2 * (precision * recall) / (precision + recall) + f_score[precision + recall == 0] = 0.0 + + if average is not None: + precision = np.average(precision) + recall = np.average(recall) + f_score = np.average(f_score) + true_sum = np.average(true_sum) # != sklearn: we keep the support + pred_sum = np.average(pred_sum) + + if return_support_pred: + return precision, recall, f_score, true_sum, pred_sum + else: + return precision, recall, f_score, true_sum diff --git a/educe/rst_dt/metrics/__init__.py b/educe/rst_dt/metrics/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/educe/rst_dt/metrics/rst_parseval.py b/educe/rst_dt/metrics/rst_parseval.py new file mode 100644 index 0000000..af8ba1c --- /dev/null +++ b/educe/rst_dt/metrics/rst_parseval.py @@ -0,0 +1,244 @@ +"""PARSEVAL metrics adapted for RST constituency trees. + +References +---------- +.. [1] `Daniel Marcu (2000). "The theory and practice of discourse + parsing and summarization." MIT press. + +""" + +from __future__ import absolute_import, print_function + +from educe.metrics.parseval import (parseval_scores, parseval_report, + parseval_detailed_report) + + +# label extraction functions +LBL_FNS = [ + ('S', lambda span: 1), + ('S+N', lambda span: span[1]), + ('S+R', lambda span: span[2]), + ('S+N+R', lambda span: '{}-{}'.format(span[2], span[1])), + # WIP 2016-11-10 add head to evals + ('S+H', lambda span: span[3]), + ('S+N+H', lambda span: '{}-{}'.format(span[1], span[3])), + ('S+R+H', lambda span: '{}-{}'.format(span[2], span[3])), + ('S+N+R+H', lambda span: '{}-{}'.format(span[2], span[1])), + # end WIP head +] + + +def rst_parseval_scores(ctree_true, ctree_pred, lbl_fn, subtree_filter=None, + labels=None, average=None): + """Compute RST PARSEVAL scores for ctree_pred wrt ctree_true. + + Notably, the root node of both ctrees is excluded from the scoring + procedure. + + Parameters + ---------- + ctree_true : list of list of RSTTree or SimpleRstTree + List of reference RST trees, one per document. + + ctree_pred : list of list of RSTTree or SimpleRstTree + List of predicted RST trees, one per document. + + lbl_fn : function, optional + Function to relabel spans. + + subtree_filter : function, optional + Function to filter all local trees. + + labels : list of string, optional + Corresponds to sklearn's target_names IMO + + average : one of {'micro', 'macro'}, optional + TODO, see scores_structured + + Returns + ------- + precision : float (if average is not None) or array of float, shape =\ + [n_unique_labels] + Weighted average of the precision of each class. 
+ + recall : float (if average is not None) or array of float, shape =\ + [n_unique_labels] + + fbeta_score : float (if average is not None) or array of float, shape =\ + [n_unique_labels] + + support : int (if average is not None) or array of int, shape =\ + [n_unique_labels] + The number of occurrences of each label in ``ctree_true``. + + """ + return parseval_scores(ctree_true, ctree_pred, + subtree_filter=subtree_filter, + exclude_root=True, lbl_fn=lbl_fn, + labels=labels, average=average) + + +def rst_parseval_report(ctree_true, ctree_pred, ctree_type='RST', + subtree_filter=None, metric_types=None, + digits=4, print_support_pred=True, + per_doc=False, + add_trivial_spans=False, + stringent=False): + """Build a text report showing the PARSEVAL discourse metrics. + + This is the simplest report we need to generate, it corresponds + to the arrays of results from the literature. + Metrics are calculated globally (average='micro'). + + Parameters + ---------- + ctree_true: TODO + TODO + + ctree_pred: TODO + TODO + + ctree_type : one of {'RST', 'SimpleRST'}, defaults to 'RST' + Type of ctrees considered in the evaluation procedure. + 'RST' is the standard type of ctrees used in the RST corpus, + it triggers the exclusion of the root node from the evaluation + but leaves are kept. + 'SimpleRST' is a binarized variant of RST trees where each + internal node corresponds to an attachment decision ; in other + words, it is a binary ctree where the nuclearity and relation label + are moved one node up compared to the standard RST trees. This + triggers the exclusion of leaves from the eval, but the root node + is kept. + + subtree_filter: function, optional + Function to filter all local trees. + + metric_types : list of strings, optional + Metrics that need to be included in the report ; if None is + given, defaults to ['S', 'S+N', 'S+R', 'S+N+R']. + + digits : int, defaults to 4 + Number of decimals to print. + + print_support_pred : boolean, defaults to True + If True, the predicted support, i.e. the number of predicted + spans, is also displayed. This is useful for non-binary ctrees + as the number of spans in _true and _pred can differ. + + per_doc : boolean, defaults to False + If True, compute p, r, f for each doc separately then compute the + mean of each score over docs. This is *not* the correct + implementation, but it corresponds to that in DPLP. + + add_trivial_spans : boolean, defaults to False + If True, trivial spans 0-0, 0-n, 1-n are added ; this is meant to + replicate the evaluation procedure of Li et al.'s dependency RST + parser. + + stringent : boolean, defaults to False + TODO + """ + # filter root or leaves, depending on the type of ctree + if ctree_type not in ['RST', 'SimpleRST']: + raise ValueError("ctree_type should be one of {'RST', 'SimpleRST'}") + if ctree_type == 'RST': + # standard RST ctree: exclude root + exclude_root = True + subtree_filter = subtree_filter + elif ctree_type == 'SimpleRST': + # SimpleRST variant: keep root, exclude leaves + exclude_root = False # TODO try True first, should get same as before + not_leaf = lambda t: t.height() > 2 # TODO unit test! 
+ if subtree_filter is None: + subtree_filter = not_leaf + else: + subtree_filter = lambda t: not_leaf(t) and subtree_filter(t) + + # select metrics and the corresponding functions + if metric_types is None: + # metric_types = ['S', 'S+N', 'S+R', 'S+N+R'] + metric_types = [x[0] for x in LBL_FNS] + if set(metric_types) - set(x[0] for x in LBL_FNS): + raise ValueError('Unknown metric types in {}'.format(metric_types)) + metric2lbl_fn = dict(LBL_FNS) + lbl_fns = [(metric_type, metric2lbl_fn[metric_type]) + for metric_type in metric_types] + + return parseval_report(ctree_true, ctree_pred, exclude_root=exclude_root, + subtree_filter=subtree_filter, lbl_fns=lbl_fns, + digits=digits, + print_support_pred=print_support_pred, + per_doc=per_doc, + add_trivial_spans=add_trivial_spans) + + +def rst_parseval_detailed_report(ctree_true, ctree_pred, ctree_type='RST', + subtree_filter=None, metric_type='S+R', + labels=None, sort_by_support=True, + digits=4, per_doc=False): + """Build a text report showing the PARSEVAL discourse metrics per label. + + Metrics are calculated globally (average='micro'). + + Parameters + ---------- + ctree_true: TODO + TODO + + ctree_pred: TODO + TODO + + ctree_type : one of {'RST', 'SimpleRST'}, defaults to 'RST' + Type of ctrees considered in the evaluation procedure. + 'RST' is the standard type of ctrees used in the RST corpus, + it triggers the exclusion of the root node from the evaluation + but leaves are kept. + 'SimpleRST' is a binarized variant of RST trees where each + internal node corresponds to an attachment decision ; in other + words, it is a binary ctree where the nuclearity and relation label + are moved one node up compared to the standard RST trees. This + triggers the exclusion of leaves from the eval, but the root node + is kept. + + subtree_filter: function, optional + Function to filter all local trees. + + metric_type : one of {'S+R', 'S+N+R'}, defaults to 'S+R' + Metric that need to be included in the report. + + digits : int, defaults to 4 + Number of decimals to print. + + per_doc: boolean, defaults to False + If True, compute p, r, f for each doc separately then compute the + mean of each score over docs. This is *not* the correct + implementation, but it corresponds to that in DPLP. + + """ + # filter root or leaves, depending on the type of ctree + if ctree_type not in ['RST', 'SimpleRST']: + raise ValueError("ctree_type should be one of {'RST', 'SimpleRST'}") + if ctree_type == 'RST': + # standard RST ctree: exclude root + exclude_root = True + subtree_filter = subtree_filter + elif ctree_type == 'SimpleRST': + # SimpleRST variant: keep root, exclude leaves + exclude_root = False # TODO try True first, should get same as before + not_leaf = lambda t: t.height() > 2 # TODO unit test! 
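+        # NB: in an NLTK tree, a node whose children are all leaves has
+        # height 2, so `t.height() > 2` keeps only the internal nodes
+        # above the EDU pre-terminals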
+ if subtree_filter is None: + subtree_filter = not_leaf + else: + subtree_filter = lambda t: not_leaf(t) and subtree_filter(t) + + # select metrics and the corresponding functions + if metric_type not in set(x[0] for x in LBL_FNS): + raise ValueError('Unknown metric type: {}'.format(metric_type)) + metric2lbl_fn = dict(LBL_FNS) + lbl_fn = (metric_type, metric2lbl_fn[metric_type]) + + return parseval_detailed_report( + ctree_true, ctree_pred, exclude_root=exclude_root, + subtree_filter=subtree_filter, lbl_fn=lbl_fn, + labels=labels, sort_by_support=sort_by_support, + digits=digits, per_doc=per_doc) From 43ebe631c1daa668f7858e9441137d438573a0f6 Mon Sep 17 00:00:00 2001 From: moreymat Date: Fri, 27 Jan 2017 15:03:22 +0100 Subject: [PATCH 03/44] DOC+MAINT educe.external improve docstrings, style --- educe/external/corenlp.py | 2 +- educe/external/parser.py | 26 ++++++++++++++++++++------ educe/external/postag.py | 18 +++++++++++++++++- 3 files changed, 38 insertions(+), 8 deletions(-) diff --git a/educe/external/corenlp.py b/educe/external/corenlp.py index e959c5d..e27746f 100644 --- a/educe/external/corenlp.py +++ b/educe/external/corenlp.py @@ -46,7 +46,7 @@ class CoreNlpToken(postag.Token): Attributes ---------- - features: dict(string, string) + features : dict(string, string) Additional info found by corenlp about the token (eg. `x.features['lemma']`) """ diff --git a/educe/external/parser.py b/educe/external/parser.py index 179bb0f..9ba3d1e 100644 --- a/educe/external/parser.py +++ b/educe/external/parser.py @@ -123,24 +123,38 @@ def text_span(self): @classmethod def build(cls, tree, tokens): - """ - Build an educe tree by combining an existing NLTK tree with + """Build an educe tree by combining an existing NLTK tree with some replacement leaves. The replacement leaves should correspond 1:1 to the leaves of the original tree (for example, they may contain features related to - those words + those words). + + Parameters + ---------- + tree : nltk.Tree + Original NLTK tree. + tokens : iterable + List of replacement leaves. + + Returns + ------- + ctree : ConstituencyTree + ConstituencyTree where the internal nodes have the same + labels as in the original NLTK tree and the leaves + correspond to the given list of tokens. """ toks = deque(tokens) def step(t): """Recursive helper for tree building""" if not isinstance(t, nltk.tree.Tree): - if toks: - return toks.popleft() - else: + # leaf + if not toks: raise Exception('Must have same number of input tokens ' 'as leaves in the tree') + return toks.popleft() + # internal node, recurse to kids return cls(t.label(), [step(kid) for kid in t]) return step(tree) diff --git a/educe/external/postag.py b/educe/external/postag.py index df8d71c..937f725 100644 --- a/educe/external/postag.py +++ b/educe/external/postag.py @@ -176,7 +176,23 @@ def token_spans(text, tokens, offset=0): Spans are relative to the start of the string itself, but can be shifted by passing an offset (the start of the original string's - span) + span). + + Parameters + ---------- + text : string + Base text. + + tokens : sequence of RawToken + Sequence of raw tokens in the text. + + offset : int, defaults to 0 + Offset for spans. + + Returns + ------- + res : list of Token + Sequence of proper educe `Token`s with their span. 
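+
+    Examples
+    --------
+    A minimal sketch, assuming plain ``RawToken(word, tag)`` inputs::
+
+        toks = [RawToken('hello', 'UH'), RawToken('world', 'NN')]
+        tagged = token_spans('hello world', toks)
+        # each element of `tagged` pairs a token with its character
+        # span in the base text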
""" token_words = [tok.word for tok in tokens] spans = generic_token_spans(text, token_words, offset) From 06372288d881b06d66db03e06d04f5f743052b5b Mon Sep 17 00:00:00 2001 From: moreymat Date: Fri, 27 Jan 2017 15:37:44 +0100 Subject: [PATCH 04/44] MAINT educe.pdtb flake8, pylint --- educe/pdtb/corpus.py | 4 +- educe/pdtb/parse.py | 296 ++++++++++++++++++++------------- educe/pdtb/pdtbx.py | 214 +++++++++++++++--------- educe/pdtb/tests.py | 172 ++++++++++--------- educe/pdtb/util/cmd/extract.py | 3 + educe/pdtb/util/cmd/tmp.py | 17 +- educe/pdtb/util/cmd/xml_.py | 12 +- 7 files changed, 425 insertions(+), 293 deletions(-) diff --git a/educe/pdtb/corpus.py b/educe/pdtb/corpus.py index ec5b0da..44605da 100644 --- a/educe/pdtb/corpus.py +++ b/educe/pdtb/corpus.py @@ -13,10 +13,10 @@ import educe.corpus from . import parse + # --------------------------------------------------------------------- # Corpus # --------------------------------------------------------------------- - class Reader(educe.corpus.Reader): """ See `educe.corpus.Reader` for details @@ -46,7 +46,7 @@ def slurp_subcorpus(self, cfiles, verbose=False): (counter, len(cfiles))) fname = cfiles[k] annotations = parse.parse(fname) - #annotations.set_origin(k) + # annotations.set_origin(k) corpus[k] = annotations counter = counter+1 if verbose: diff --git a/educe/pdtb/parse.py b/educe/pdtb/parse.py index e78838f..d5a8f9e 100755 --- a/educe/pdtb/parse.py +++ b/educe/pdtb/parse.py @@ -37,7 +37,6 @@ import codecs import re -import funcparserlib.parser as fp import sys if sys.version > '3': @@ -46,13 +45,15 @@ else: from StringIO import StringIO +import funcparserlib.parser as fp + + # --------------------------------------------------------------------- # parse results # --------------------------------------------------------------------- - class PdtbItem(object): @classmethod - def _prefered_order(self): + def _prefered_order(cls): """ Preferred order for printing key/value pairs """ @@ -63,10 +64,10 @@ def _prefered_order(self): 'arg1', 'arg2'] def _substr(self): - d = self.__dict__ - ks1 = [ k for k in self._prefered_order() if k in d ] - ks2 = [ k for k in d if k not in self._prefered_order() ] - return '\n '.join('%s = %s' % (k,d[k]) for k in ks1 + ks2) + d = self.__dict__ + ks1 = [k for k in self._prefered_order() if k in d] + ks2 = [k for k in d if k not in self._prefered_order()] + return '\n '.join('%s = %s' % (k, d[k]) for k in ks1 + ks2) def __str__(self): return '%s(%s)' % (self.__class__.__name__, self._substr()) @@ -82,29 +83,32 @@ def __eq__(self, other): def __ne__(self, other): return not self == other + class GornAddress(PdtbItem): def __init__(self, parts): self.parts = parts def __str__(self): - return '.'.join(map(str,self.parts)) + return '.'.join(str(x) for x in self.parts) + class Attribution(PdtbItem): def __init__(self, source, type, polarity, determinacy, selection=None): - self.source = source - self.type = type - self.polarity = polarity + self.source = source + self.type = type + self.polarity = polarity self.determinacy = determinacy - self.selection = selection + self.selection = selection def _substr(self): selStr = '@ %s' % self.selection._substr() if self.selection else '' return '%s %s %s %s%s' %\ (self.source, self.type, self.polarity, self.determinacy, selStr) + class InferenceSite(PdtbItem): def __init__(self, strpos, sentnum): - self.strpos = strpos + self.strpos = strpos self.sentnum = sentnum def _substr(self): @@ -114,6 +118,7 @@ def _substr(self): def _init_copy(cls, self, other): 
cls.__init__(self, other.strpos, other.sentnum) + class Selection(PdtbItem): def __init__(self, span, gorn, text): self.span = span @@ -130,12 +135,13 @@ def _substr(self): def _init_copy(cls, self, other): cls.__init__(self, other.span, other.gorn, other.text) + class Connective(PdtbItem): def __init__(self, text, semclass1, semclass2=None): self.text = text - assert(isinstance(semclass1, SemClass)) + assert isinstance(semclass1, SemClass) if semclass2: - assert(isinstance(semclass2, SemClass)) + assert isinstance(semclass2, SemClass) self.semclass1 = semclass1 self.semclass2 = semclass2 @@ -145,6 +151,7 @@ def _substr(self): fields.append(self.semclass2._substr()) return ' | '.join(fields) + class SemClass(PdtbItem): def __init__(self, klass): self.klass = klass @@ -152,23 +159,27 @@ def __init__(self, klass): def _substr(self): return '.'.join(self.klass) + class Sup(Selection): def __init__(self, selection): Selection._init_copy(self, selection) + class Arg(Selection): def __init__(self, selection, attribution=None, sup=None): Selection._init_copy(self, selection) if attribution: - assert(isinstance(attribution, Attribution)) + assert isinstance(attribution, Attribution) if sup: - assert(isinstance(sup , Sup)) + assert isinstance(sup, Sup) self.attribution = attribution - self.sup = sup + self.sup = sup def _substr(self): sup_str = ' + %s' % self.sup if self.sup else '' - return '%s | %s%s' % (Selection._substr(self), self.attribution, sup_str) + return '%s | %s%s' % (Selection._substr(self), self.attribution, + sup_str) + class Relation(PdtbItem): """ @@ -183,30 +194,33 @@ def __init__(self, args): self.arg1 = Arg(arg1, arg1.attribution, sup1) if sup1 else arg1 self.arg2 = Arg(arg2, arg2.attribution, sup2) if sup2 else arg2 elif len(args) == 2: - self.arg1, self.arg2 = args + self.arg1, self.arg2 = args else: - raise Exception('Was expecting either 2 or 4 arguments, but got: %d\n%s' % (len(xs), xs)) + raise ValueError('Was expecting either 2 or 4 arguments, ' + 'but got: %d\n%s' % (len(args), args)) def _substr(self): return PdtbItem._substr(self) + class ExplicitRelationFeatures(PdtbItem): def __init__(self, attribution, connhead): - assert(isinstance(attribution, Attribution)) - assert(isinstance(connhead, Connective)) + assert isinstance(attribution, Attribution) + assert isinstance(connhead, Connective) self.attribution = attribution - self.connhead = connhead + self.connhead = connhead @classmethod def _init_copy(cls, self, other): cls.__init__(self, other.attribution, other.connhead) + class ImplicitRelationFeatures(PdtbItem): def __init__(self, attribution, connective1, connective2=None): - assert(isinstance(attribution, Attribution)) - assert(isinstance(connective1, Connective)) + assert isinstance(attribution, Attribution) + assert isinstance(connective1, Connective) if connective2: - assert(isinstance(connective2, Connective)) + assert isinstance(connective2, Connective) self.attribution = attribution self.connective1 = connective1 self.connective2 = connective2 @@ -216,20 +230,22 @@ def _init_copy(cls, self, other): cls.__init__(self, other.attribution, other.connective1, other.connective2) + class AltLexRelationFeatures(PdtbItem): def __init__(self, attribution, semclass1, semclass2): - assert(isinstance(attribution, Attribution)) - assert(isinstance(semclass1, SemClass)) + assert isinstance(attribution, Attribution) + assert isinstance(semclass1, SemClass) if semclass2: - assert(isinstance(semclass2, SemClass)) + assert isinstance(semclass2, SemClass) self.attribution = 
attribution - self.semclass1 = semclass1 - self.semclass2 = semclass2 + self.semclass1 = semclass1 + self.semclass2 = semclass2 @classmethod def _init_copy(cls, self, other): cls.__init__(self, other.attribution, other.semclass1, other.semclass2) + class ExplicitRelation(Selection, ExplicitRelationFeatures, Relation): def __init__(self, selection, features, args): Relation.__init__(self, args) @@ -247,7 +263,8 @@ def __init__(self, infsite, features, args): ImplicitRelationFeatures._init_copy(self, features) def _substr(self): - return Relation._substr(self) + return Relation._substr(self) + class AltLexRelation(Selection, AltLexRelationFeatures, Relation): def __init__(self, selection, features, args): @@ -256,7 +273,8 @@ def __init__(self, selection, features, args): AltLexRelationFeatures._init_copy(self, features) def _substr(self): - return Relation._substr(self) + return Relation._substr(self) + class EntityRelation(InferenceSite, Relation): def __init__(self, infsite, args): @@ -264,7 +282,8 @@ def __init__(self, infsite, args): InferenceSite._init_copy(self, infsite) def _substr(self): - return Relation._substr(self) + return Relation._substr(self) + class NoRelation(InferenceSite, Relation): def __init__(self, infsite, args): @@ -272,7 +291,7 @@ def __init__(self, infsite, args): InferenceSite._init_copy(self, infsite) def _substr(self): - return Relation._substr(self) + return Relation._substr(self) # --------------------------------------------------------------------- # not-quite-lexing @@ -288,11 +307,12 @@ def _substr(self): # provide some abstractions over tokens, we could maybe simplify the # parser a lot... which could in turn make it faster? + class _Char(object): def __init__(self, value, abspos, line, relpos): - self.value = value + self.value = value self.abspos = abspos - self.line = line + self.line = line self.relpos = relpos def __eq__(self, other): @@ -307,50 +327,58 @@ def __repr__(self): char = 'SP' elif self.value == '\t': char = 'TAB' - return '[%s] %d (line: %d col: %d)' % (char, self.abspos, self.line, self.relpos) + return '[%s] %d (line: %d col: %d)' % ( + char, self.abspos, self.line, self.relpos) + def _annotate_production(s): return s + def _annotate_debug(s): """ Add line/col char number """ def tokens(): line = 1 - col = 1 - pos = 1 + col = 1 + pos = 1 for c in StringIO(s).read(): yield _Char(c, pos, line, col) pos += 1 if c == '\n': line += 1 - col = 1 + col = 1 else: col += 1 return list(tokens()) + # --------------------------------------------------------------------- # funcparserlib utilities # --------------------------------------------------------------------- +_DEBUG = 0 # turn this on to get line number hints +_const = lambda x: lambda _: x +_unarg = lambda f: lambda x: f(*x) -_DEBUG = 0 # turn this on to get line number hints -_const = lambda x: lambda _: x -_unarg = lambda f: lambda x: f(*x) def _cons(pair): head, tail = pair return [head] + tail + def _mkstr_debug(x): return "".join(c.value for c in x) + def _mkstr_production(x): return "".join(x) -_any = fp.some(_const(True)) -def _intersperse(d,xs): +_any = fp.some(_const(True)) + + +def _intersperse(d, xs): """ a -> [a] -> [a] """ @@ -362,6 +390,7 @@ def _intersperse(d,xs): xs2.append(x) return xs2 + def _not_followed_by(p): """Parser(a, b) -> Parser(a, b) @@ -380,6 +409,7 @@ def _helper(tokens, s): _helper.name = u'not_followed_by{ %s }' % p.name return _helper + def _skipto(p): """Parser(a, b) -> Parser(a, [a]) @@ -391,7 +421,7 @@ def _skipto(p): def _helper(tokens, s): """Iterative 
implementation preventing the stack overflow.""" res = [] - s2 = s + s2 = s while s2.pos < len(tokens): try: (v, s3) = p.run(tokens, s2) @@ -405,27 +435,35 @@ def _helper(tokens, s): _helper.name = u'{ skip_to %s }' % p.name return _helper + def _skipto_mkstr(p): return _skipto(p) >> _mkstr + def _satisfies_debug(fn): - return fp.some(lambda t:fn(t.value)) + return fp.some(lambda t: fn(t.value)) + def _satisfies_production(fn): return fp.some(fn) + def _oneof(xs): return _satisfies(lambda x: x in xs) + def _sepby(delim, p): return p + fp.many(fp.skip(delim) + p) >> _cons + def _sequence(ps): return reduce(lambda x, y: x + y, ps) + def _many_char(fn): return fp.many(_satisfies(fn)) >> _mkstr + def _noise(xs): """String -> Parser(a, ()) @@ -436,15 +474,15 @@ def _helper(tokens, s): """Iterative implementation preventing the stack overflow.""" res = [] start = s.pos - end = start + len(xs) - toks = tokens[start : end] + end = start + len(xs) + toks = tokens[start:end] if _DEBUG: - vals = [ t.value for t in toks ] + vals = [t.value for t in toks] else: vals = toks if vals == xs: pos = s.pos + len(xs) - s2 = fp.State(pos, max(pos, s.max)) + s2 = fp.State(pos, max(pos, s.max)) return fp._Ignored(()), s2 else: raise fp.NoParseError(u'Did not match literal ' + xs, s) @@ -453,28 +491,29 @@ def _helper(tokens, s): return _helper if _DEBUG: - _annotate = _annotate_debug - _mkstr = _mkstr_debug + _annotate = _annotate_debug + _mkstr = _mkstr_debug _satisfies = _satisfies_debug else: - _annotate = _annotate_production - _mkstr = _mkstr_production + _annotate = _annotate_production + _mkstr = _mkstr_production _satisfies = _satisfies_production + # --------------------------------------------------------------------- # elementary parts # --------------------------------------------------------------------- - -_nat = fp.oneplus(_satisfies(lambda c: c.isdigit())) >> (lambda x:int(_mkstr(x))) -_nl = fp.skip(_oneof("\r\n")) +_nat = fp.oneplus(_satisfies(lambda c: c.isdigit())) >> (lambda x: int(_mkstr(x))) +_nl = fp.skip(_oneof("\r\n")) _comma = fp.skip(_oneof(",")) _semicolon = fp.skip(_oneof(";")) -_fullstop = fp.skip(_oneof(".")) +_fullstop = fp.skip(_oneof(".")) # horizontal only -_sp = fp.skip(_many_char(lambda x:x not in "\r\n" and x.isspace())) -_allsp = fp.skip(_many_char(lambda x:x.isspace())) -_alphanum_str = _many_char(lambda x:x.isalnum()) -_eof = fp.skip(fp.finished) +_sp = fp.skip(_many_char(lambda x: x not in "\r\n" and x.isspace())) +_allsp = fp.skip(_many_char(lambda x: x.isspace())) +_alphanum_str = _many_char(lambda x: x.isalnum()) +_eof = fp.skip(fp.finished) + class _OptionalBlock: """ @@ -487,7 +526,8 @@ class _OptionalBlock: """ def __init__(self, p, avoid=None): self.avoid = avoid - self.p = p + self.p = p + def _words(ps): """ @@ -495,6 +535,7 @@ def _words(ps): """ return _sequence(_intersperse(_sp, ps)) + def _lines(ps): if not ps: raise Exception('_lines must be called with at least one parser') @@ -505,41 +546,44 @@ def _prefix_nl(y): return _nl + y def _next(y, prefix=_prefix_nl): - if isinstance(y,_OptionalBlock): + if isinstance(y, _OptionalBlock): if y.avoid: # stop parsing if we see the distractor distractor = prefix(y.avoid) - p_next = _not_followed_by(distractor) + prefix(y.p) + p_next = _not_followed_by(distractor) + prefix(y.p) else: - p_next = prefix(y.p) + p_next = prefix(y.p) return fp.maybe(p_next) else: return prefix(y) - def _combine(x,y): + def _combine(x, y): return x + _next(y) return reduce(_combine, ps) + def _section_begin(t): return _noise('____' + t + 
'____') + def _subsection_begin(t): return _noise('#### ' + t + ' ####') + _subsection_end = _noise('##############') -_bar = _noise('_' * 56) +_bar = _noise('_' * 56) _span = _nat + _noise('..') + _nat >> tuple _gorn = _sepby(_comma, _nat) >> GornAddress _StringPosition = _nat _SentenceNumber = _nat + # --------------------------------------------------------------------- # selections - funcparserlib # --------------------------------------------------------------------- - -_SpanList = _sepby(_semicolon, _span) +_SpanList = _sepby(_semicolon, _span) _GornAddressList = _sepby(_semicolon, _gorn) _RawText = _lines([_subsection_begin('Text'), _skipto_mkstr(_nl + _subsection_end)]) @@ -550,13 +594,13 @@ def _subsection_begin(t): _inferenceSite =\ _lines([_StringPosition, _SentenceNumber]) >> _unarg(InferenceSite) + # --------------------------------------------------------------------- # features # --------------------------------------------------------------------- - -_Source = _alphanum_str -_Type = _alphanum_str -_Polarity = _alphanum_str +_Source = _alphanum_str +_Type = _alphanum_str +_Polarity = _alphanum_str _Determinacy = _alphanum_str _attributionCoreFeatures =\ @@ -570,46 +614,54 @@ def _subsection_begin(t): # Expansion.Alternative.Chosen alternative => # Expansion / Alternative / "Chosen alternative " -_SemanticClassWord = _many_char(lambda x:x in [' ', '-'] or x.isalnum()) +_SemanticClassWord = _many_char(lambda x: x in [' ', '-'] or x.isalnum()) _SemanticClassN = _sepby(_fullstop, _SemanticClassWord) >> SemClass _SemanticClass1 = _SemanticClassN _SemanticClass2 = _SemanticClassN -_semanticClass = _SemanticClass1 + fp.maybe(_sp + _comma + _sp + _SemanticClass2) +_semanticClass = _SemanticClass1 + fp.maybe( + _sp + _comma + _sp + _SemanticClass2) # always followed by a comma (yeah, a bit clunky) _ConnHead = _skipto_mkstr(_comma) -_Conn1 = _ConnHead -_Conn2 = _ConnHead +_Conn1 = _ConnHead +_Conn2 = _ConnHead -def _mkConnective(c,semclasses): + +def _mkConnective(c, semclasses): return Connective(c, *semclasses) + _connHeadSemanticClass = _ConnHead + _sp + _semanticClass >> _unarg(_mkConnective) -_conn1SemanticClass = _Conn1 + _sp + _semanticClass >> _unarg(_mkConnective) -_conn2SemanticClass = _Conn2 + _sp + _semanticClass >> _unarg(_mkConnective) +_conn1SemanticClass = _Conn1 + _sp + _semanticClass >> _unarg(_mkConnective) +_conn2SemanticClass = _Conn2 + _sp + _semanticClass >> _unarg(_mkConnective) + # --------------------------------------------------------------------- # arguments and supplementary information # --------------------------------------------------------------------- - def _Arg(name): return _section_begin(name.capitalize()) + def _Sup(name): return _section_begin(name.capitalize()) + def _arg(name): p = _lines([_Arg(name), _selection, _attributionFeatures]) >> _unarg(Arg) return p + def _arg_no_features(name): p = _lines([_Arg(name), _selection]) >> Arg return p + def _sup(name): p = _lines([_Sup(name), _selection]) >> Sup return p + # this is a bit yucky because I don't really know how to express # optional first blocks and make sure I handle the intervening # newlines correctly @@ -619,29 +671,29 @@ def _mk_args_and_sups(): _OptionalBlock(_sup('sup2'))] with_sup1 = _lines([_sup('sup1')] + rest) >> tuple - sans_sup1 = _lines(rest) >> (lambda xs : tuple([None] + list(xs))) - return with_sup1 | sans_sup1 # yuck :-( + sans_sup1 = _lines(rest) >> (lambda xs: tuple([None] + list(xs))) + return with_sup1 | sans_sup1 # yuck :-( _args_and_sups = _mk_args_and_sups() 
_args_only =\ _lines([_arg_no_features('arg1'), _arg_no_features('arg2')]) >> tuple + # --------------------------------------------------------------------- # relations # --------------------------------------------------------------------- - __Explicit = 'Explicit' -__Implict = 'Implicit' -__AltLex = 'AltLex' -__EntRel = 'EntRel' -__NoRel = 'NoRel' +__Implict = 'Implicit' +__AltLex = 'AltLex' +__EntRel = 'EntRel' +__NoRel = 'NoRel' _Explicit = _section_begin(__Explicit) -_Implict = _section_begin(__Implict) -_AltLex = _section_begin(__AltLex) -_EntRel = _section_begin(__EntRel) -_NoRel = _section_begin(__NoRel) +_Implict = _section_begin(__Implict) +_AltLex = _section_begin(__AltLex) +_EntRel = _section_begin(__EntRel) +_NoRel = _section_begin(__NoRel) _explicitRelationFeatures =\ _lines([_attributionFeatures, _connHeadSemanticClass])\ @@ -649,7 +701,7 @@ def _mk_args_and_sups(): _altLexRelationFeatures =\ _lines([_attributionFeatures, _semanticClass])\ - >> (lambda x:AltLexRelationFeatures(x[0], *x[1])) + >> (lambda x: AltLexRelationFeatures(x[0], *x[1])) _afterImplicitRelationFeatures =\ _section_begin('Arg1') | _section_begin('Sup1') @@ -681,45 +733,50 @@ def _mk_args_and_sups(): _lines([_inferenceSite, _args_only])\ >> _unarg(NoRelation) -_relationParts=\ - [(__Explicit, _explicitRelation), - (__Implict, _implicitRelation), - (__AltLex, _altLexRelation), - (__EntRel, _entityRelation), - (__NoRel, _noRelation), - ] +_relationParts = [ + (__Explicit, _explicitRelation), + (__Implict, _implicitRelation), + (__AltLex, _altLexRelation), + (__EntRel, _entityRelation), + (__NoRel, _noRelation), +] + def _relationBody(ty, core): return _lines([_section_begin(ty), core]) + def _orRels(rs): """ R1 or R2 or .. RN """ - cores = [ _relationBody(*r) for r in rs ] + cores = [_relationBody(*r) for r in rs] return _lines([_bar, reduce(lambda x, y: x | y, cores), _bar]) + def _oneRel(ty, core): return _lines([_bar, _relationBody(ty, core), _bar]) -_relation = _orRels(_relationParts) + +_relation = _orRels(_relationParts) _relationList = _sepby(_nl, _relation) -_pdtbRelation = _relation + _allsp + _eof -_pdtbFile = _relationList + _allsp + _eof +_pdtbRelation = _relation + _allsp + _eof +_pdtbFile = _relationList + _allsp + _eof + # --------------------------------------------------------------------- # tests and examples # --------------------------------------------------------------------- - def split_relations(s): frame = r'________________________________________________________\n' +\ r'.*?' +\ r'________________________________________________________' return re.findall(frame, s, re.DOTALL) + def parse_relation(s): """ Parse a single relation or throw a ParseException. @@ -733,16 +790,23 @@ def parse_relation(s): parser = _oneRel(rtype, rules[rtype]) + _eof return parser.parse(_annotate(s)) + def parse(path): - """ - Parse a single .pdtb file and return the list of relations found - within + """Retrieve the list of relations found in a single .pdtb file. + + Parameters + ---------- + path : string + Path to the .pdtb file - :rtype: [Relation] + Returns + ------- + relations : list of Relation + List of relations found. """ - doc = codecs.open(path, 'r', 'iso8859-1').read() + doc = codecs.open(path, 'r', 'iso8859-1').read() return _pdtbFile.parse(_annotate(doc)) # alternatively: using a regular expression to split into relations # and parsing each relation separately - perhaps more robust? 
- #splits = split_relations(doc) - #return [ parse_relation(s) for s in splits ] + # splits = split_relations(doc) + # return [ parse_relation(s) for s in splits ] diff --git a/educe/pdtb/pdtbx.py b/educe/pdtb/pdtbx.py index b1ab991..f4a8faa 100644 --- a/educe/pdtb/pdtbx.py +++ b/educe/pdtb/pdtbx.py @@ -19,125 +19,151 @@ * implicitRelations can have multiple connectives """ -import xml.etree.cElementTree as ET # python 2.5 and later +import xml.etree.cElementTree as ET # python 2.5 and later -import educe.pdtb.parse as pdtb import educe.pdtb.parse as ty -from educe.internalutil import on_single_element, EduceXmlException, indent_xml +from educe.internalutil import (on_single_element, EduceXmlException, + indent_xml) + # --------------------------------------------------------------------- # XML to internal structure # --------------------------------------------------------------------- - def _read_GornAddressList(attr): return [ty.GornAddress([int(y) for y in x.split(',')]) for x in attr.split(';')] + def _read_SpanList(attr): - return [tuple(map(int,x.split('..'))) for x in attr.split(';')] + return [tuple([int(y) for y in x.split('..')]) for x in attr.split(';')] + def _read_SemClass(attr): return ty.SemClass(attr.split('.')) + def _read_Selection(node): attr = node.attrib - return ty.Selection(span = _read_SpanList(attr['spanList']), - gorn = _read_GornAddressList(attr['gornList']), - text = on_single_element(node, None, lambda x:x.text, 'text')) + return ty.Selection(span=_read_SpanList(attr['spanList']), + gorn=_read_GornAddressList(attr['gornList']), + text=on_single_element(node, None, lambda x: x.text, + 'text')) + def _read_InferenceSite(node): attr = node.attrib - return ty.InferenceSite(strpos = int(attr['strpos']), - sentnum = int(attr['sentnum'])) + return ty.InferenceSite(strpos=int(attr['strpos']), + sentnum=int(attr['sentnum'])) + def _read_Connective(node): attr = node.attrib semclass1_ = attr['semclass1'] - semclass2_ = attr.get('semclass2', None) # optional - semclass1 = _read_SemClass(semclass1_) - semclass2 = _read_SemClass(semclass2_) if semclass2_ else None - return ty.Connective(text = attr['text'], - semclass1 = semclass1, - semclass2 = semclass2) + semclass2_ = attr.get('semclass2', None) # optional + semclass1 = _read_SemClass(semclass1_) + semclass2 = _read_SemClass(semclass2_) if semclass2_ else None + return ty.Connective(text=attr['text'], + semclass1=semclass1, + semclass2=semclass2) + def _read_Attribution(node): - attr = node.attrib + attr = node.attrib selection = on_single_element(node, (), _read_Selection, 'selection') - return ty.Attribution(polarity = attr['polarity'], - determinacy = attr['determinacy'], - type = attr['type'], - source = attr['source'], - selection = None if selection is () else selection) + return ty.Attribution(polarity=attr['polarity'], + determinacy=attr['determinacy'], + type=attr['type'], + source=attr['source'], + selection=(None if selection is () else selection)) + def _read_Sup(node): return ty.Sup(_read_Selection(node)) + def _read_Arg(node): - sup = on_single_element(node, (), _read_Sup, 'sup') + sup = on_single_element(node, (), _read_Sup, 'sup') attribution = on_single_element(node, (), _read_Attribution, 'attribution') - return ty.Arg(selection = _read_Selection(node), - attribution = None if attribution is () else attribution, - sup = None if sup is () else sup) + return ty.Arg(selection=_read_Selection(node), + attribution=(None if attribution is () else attribution), + sup=(None if sup is () else sup)) + def 
_read_Args(node): - args=node.findall('arg') + args = node.findall('arg') if len(args) != 2: - raise EduceXmlException('Was expecting exactly two arguments (got %d)' % len(args)) - return tuple(map(_read_Arg, args)) + raise EduceXmlException('Was expecting exactly two arguments ' + '(got %d)' % len(args)) + return tuple([_read_Arg(x) for x in args]) + def _read_ExplicitRelationFeatures(node): - attribution = on_single_element(node, None, _read_Attribution, 'attribution') - connhead = on_single_element(node, None, _read_Connective, 'connhead') - return ty.ExplicitRelationFeatures(attribution = attribution, - connhead = connhead) + attribution = on_single_element(node, None, _read_Attribution, + 'attribution') + connhead = on_single_element(node, None, _read_Connective, 'connhead') + return ty.ExplicitRelationFeatures(attribution=attribution, + connhead=connhead) + def _read_ExplicitRelation(node): - return ty.ExplicitRelation(selection = _read_Selection(node), - features = _read_ExplicitRelationFeatures(node), - args = _read_Args(node)) + return ty.ExplicitRelation(selection=_read_Selection(node), + features=_read_ExplicitRelationFeatures(node), + args=_read_Args(node)) + def _read_ImplicitRelationFeatures(node): connectives = node.findall('connective') if len(connectives) == 0: - raise EduceXmlException('Was expecting at least one connective (got none)') + raise EduceXmlException('Was expecting at least one connective ' + '(got none)') elif len(connectives) > 2: - raise EduceXmlException('Was expecting no more than two connectives (got %d)' % len(connectives)) + raise EduceXmlException('Was expecting no more than two connectives ' + '(got %d)' % len(connectives)) - attribution = on_single_element(node, None, _read_Attribution, 'attribution') + attribution = on_single_element(node, None, _read_Attribution, + 'attribution') connective1 = _read_Connective(connectives[0]) - connective2 = _read_Connective(connectives[1]) if len(connectives) == 2 else None - return ty.ImplicitRelationFeatures(attribution = attribution, - connective1 = connective1, - connective2 = connective2) + connective2 = (_read_Connective(connectives[1]) if len(connectives) == 2 + else None) + return ty.ImplicitRelationFeatures(attribution=attribution, + connective1=connective1, + connective2=connective2) + def _read_ImplicitRelation(node): - return ty.ImplicitRelation(infsite = _read_InferenceSite(node), - features = _read_ImplicitRelationFeatures(node), - args = _read_Args(node)) + return ty.ImplicitRelation(infsite=_read_InferenceSite(node), + features=_read_ImplicitRelationFeatures(node), + args=_read_Args(node)) + def _read_AltLexRelationFeatures(node): - attribution = on_single_element(node, None, _read_Attribution, 'attribution') - attr = node.attrib - semclass1_ = attr['semclass1'] - semclass2_ = attr.get('semclass2', None) # optional - semclass1 = _read_SemClass(semclass1_) - semclass2 = _read_SemClass(semclass2_) if semclass2_ else None - return ty.AltLexRelationFeatures(attribution = attribution, - semclass1 = semclass1, - semclass2 = semclass2) + attribution = on_single_element(node, None, _read_Attribution, + 'attribution') + attr = node.attrib + semclass1_ = attr['semclass1'] + semclass2_ = attr.get('semclass2', None) # optional + semclass1 = _read_SemClass(semclass1_) + semclass2 = _read_SemClass(semclass2_) if semclass2_ else None + return ty.AltLexRelationFeatures(attribution=attribution, + semclass1=semclass1, + semclass2=semclass2) + def _read_AltLexRelation(node): - return ty.AltLexRelation(selection = 
_read_Selection(node), - features = _read_AltLexRelationFeatures(node), - args = _read_Args(node)) + return ty.AltLexRelation(selection=_read_Selection(node), + features=_read_AltLexRelationFeatures(node), + args=_read_Args(node)) + def _read_EntityRelation(node): - return ty.EntityRelation(infsite = _read_InferenceSite(node), - args = _read_Args(node)) + return ty.EntityRelation(infsite=_read_InferenceSite(node), + args=_read_Args(node)) + def _read_NoRelation(node): - return ty.NoRelation(infsite = _read_InferenceSite(node), - args = _read_Args(node)) + return ty.NoRelation(infsite=_read_InferenceSite(node), + args=_read_Args(node)) + def read_Relation(node): tag = node.tag @@ -152,62 +178,75 @@ def read_Relation(node): elif tag == 'noRelation': return _read_NoRelation(node) else: - raise EduceXmlException("Don't know how to read relation with name %s" % tag) + raise EduceXmlException("Don't know how to read relation with name " + "%s" % tag) + def read_Relations(node): return [read_Relation(x) for x in node] + def read_pdtbx_file(filename): tree = ET.parse(filename) return read_Relations(tree.getroot()) + # --------------------------------------------------------------------- # internal structure to XML # --------------------------------------------------------------------- - def _Selection_xml(itm, name='selection'): elm = ET.Element(name) txt = ET.SubElement(elm, 'text') txt.text = itm.text - elm.attrib =\ - {'gornList' : _GornAddressList_xml(itm.gorn), - 'spanList' : _SpanList_xml(itm.span) } + elm.attrib = { + 'gornList': _GornAddressList_xml(itm.gorn), + 'spanList': _SpanList_xml(itm.span) + } return elm + def _InferenceSite_xml(itm, name='inferenceSite'): elm = ET.Element(name) - elm.attrib =\ - {'strpos' : str(itm.strpos), - 'sentnum' : str(itm.sentnum)} + elm.attrib = { + 'strpos': str(itm.strpos), + 'sentnum': str(itm.sentnum) + } return elm def _GornAddressList_xml(itm): - return ";".join(map(_GornAddress_xml, itm)) + return ";".join([_GornAddress_xml(x) for x in itm]) + def _SpanList_xml(itm): - return ";".join(map(_Span_xml, itm)) + return ";".join([_Span_xml(x) for x in itm]) + def _GornAddress_xml(itm): - return ",".join(map(str, itm.parts)) + return ",".join([str(x) for x in itm.parts]) + def _Span_xml(itm): return "%d..%d" % itm + def _Attribution_xml(itm): elm = ET.Element('attribution') - elm.attrib = \ - {'polarity':itm.polarity, - 'determinacy':itm.determinacy, - 'source':itm.source, - 'type':itm.type} + elm.attrib = { + 'polarity': itm.polarity, + 'determinacy': itm.determinacy, + 'source': itm.source, + 'type': itm.type + } if itm.selection: elm.append(_Selection_xml(itm.selection)) return elm + def _SemClass_xml(itm): return ".".join(itm.klass) + def _Connective_xml(itm, name='connective'): elm = ET.Element(name) elm.attrib['semclass1'] = _SemClass_xml(itm.semclass1) @@ -216,9 +255,11 @@ def _Connective_xml(itm, name='connective'): elm.attrib['text'] = itm.text return elm + def _Sup_xml(itm): return _Selection_xml(itm, 'sup') + def _Arg_xml(itm): elm = _Selection_xml(itm, 'arg') if itm.attribution: @@ -227,10 +268,12 @@ def _Arg_xml(itm): elm.append(_Sup_xml(itm.sup)) return elm + def _RelationArgsXml(itm): return [_Arg_xml(itm.arg1), _Arg_xml(itm.arg2)] + def _ExplicitRelation_xml(itm): elm = _Selection_xml(itm, 'explicitRelation') elm.append(_Attribution_xml(itm.attribution)) @@ -238,6 +281,7 @@ def _ExplicitRelation_xml(itm): elm.extend(_RelationArgsXml(itm)) return elm + def _ImplicitRelation_xml(itm): elm = _InferenceSite_xml(itm, 'implicitRelation') 
elm.append(_Attribution_xml(itm.attribution)) @@ -247,6 +291,7 @@ def _ImplicitRelation_xml(itm): elm.extend(_RelationArgsXml(itm)) return elm + def _AltLexRelation_xml(itm): elm = _Selection_xml(itm, 'altLexRelation') elm.attrib['semclass1'] = _SemClass_xml(itm.semclass1) @@ -256,16 +301,19 @@ def _AltLexRelation_xml(itm): elm.extend(_RelationArgsXml(itm)) return elm + def _EntityRelation_xml(itm): elm = _InferenceSite_xml(itm, 'entityRelation') elm.extend(_RelationArgsXml(itm)) return elm + def _NoRelation_xml(itm): elm = _InferenceSite_xml(itm, 'noRelation') elm.extend(_RelationArgsXml(itm)) return elm + def Relation_xml(itm): if isinstance(itm, ty.ExplicitRelation): return _ExplicitRelation_xml(itm) @@ -278,14 +326,18 @@ def Relation_xml(itm): elif isinstance(itm, ty.NoRelation): return _NoRelation_xml(itm) else: - raise Exception("Don't know how to translate relation of type %s" % type(itm)) + raise Exception("Don't know how to translate relation of type " + "%s" % type(itm)) + def Relations_xml(itms): elm = ET.Element('relations') - elm.extend(map(Relation_xml, itms)) + elm.extend([Relation_xml(x) for x in itms]) return elm + def write_pdtbx_file(filename, relations): xml = Relations_xml(relations) indent_xml(xml) - ET.ElementTree(xml).write(filename, encoding='utf-8', xml_declaration=True) + ET.ElementTree(xml).write(filename, encoding='utf-8', + xml_declaration=True) diff --git a/educe/pdtb/tests.py b/educe/pdtb/tests.py index 7765b12..7b5fd94 100644 --- a/educe/pdtb/tests.py +++ b/educe/pdtb/tests.py @@ -2,13 +2,14 @@ import glob import sys import unittest -import xml.etree.cElementTree as ET # python 2.5 and later +import xml.etree.cElementTree as ET # python 2.5 and later import educe.pdtb.parse as p import educe.pdtb.pdtbx as x -from educe.internalutil import indent_xml +from educe.internalutil import indent_xml -ex_txt="""#### Text #### + +ex_txt = """#### Text #### federal thrift regulators ordered it to suspend @@ -18,27 +19,27 @@ dividend payments on its two classes of preferred stock ##############""" -ex_selection="""36..139 +ex_selection = """36..139 0,1,1;2,1 #### Text #### federal thrift regulators ordered it to suspend dividend payments on its two classes of preferred stock ##############""" -ex_implicit_attribution="""#### Features #### +ex_implicit_attribution = """#### Features #### Wr, Comm, Null, Null also, Expansion.Conjunction""" ### -ex_implicit_features="""#### Features #### +ex_implicit_features = """#### Features #### Wr, Comm, Null, Null in particular, Expansion.Restatement.Specification because, Contingency.Cause.Reason""" -ex_attribution1="""#### Features #### +ex_attribution1 = """#### Features #### Ot, Comm, Null, Null""" -ex_attribution2="""#### Features #### +ex_attribution2 = """#### Features #### Ot, Comm, Null, Null 9..35 0,0;0,1,0;0,1,2;0,2 @@ -46,17 +47,17 @@ CenTrust Savings Bank said ##############""" -ex_sup1="""____Sup1____ +ex_sup1 = """____Sup1____ 1730..1799 11,2,3 #### Text #### blop blop split shares ##############""" -ex_implicit_rel=""" +ex_implicit_rel = """ """ -ex_frame="""________________________________________________________ +ex_frame = """________________________________________________________ blah blah bla _____tahueoa______ bop @@ -66,146 +67,157 @@ class PdtbParseTest(unittest.TestCase): def assertParse(self, parser, expected, txt): - parser = parser + p._eof - res = parser.parse(p._annotate(txt)) + parser = parser + p._eof + res = parser.parse(p._annotate(txt)) self.assertEqual(expected, res) def test_skipto(self): 
expected = 'blah blah blah hooyeah' - txt = expected + ',' + txt = expected + ',' self.assertParse(p._skipto_mkstr(p._comma), expected, txt) def test_many_char(self): expected = 'abc123' - txt = expected - self.assertParse(p._many_char(lambda x:x.isalnum()), expected, txt) + txt = expected + self.assertParse(p._many_char(lambda x: x.isalnum()), expected, txt) def test_lines(self): expected = 'abc' - txt = 'a\nb\nc' - char = lambda x:p._oneof(x) - parser = p._lines([char("a"), char("b"), char("c")]) >> p._mkstr + txt = 'a\nb\nc' + char = lambda x: p._oneof(x) + parser = p._lines([char("a"), char("b"), char("c")]) >> p._mkstr self.assertParse(parser, expected, txt) - parser = p._lines([char("a"), char("b"), p._OptionalBlock(char("c"))]) >> p._mkstr + parser = p._lines([char("a"), char("b"), p._OptionalBlock(char("c"))]) >> p._mkstr self.assertParse(parser, expected, txt) def test_tok(self): - expected = [p._Char('h',1,1,1), - p._Char('i',2,1,2), - p._Char('\n',3,1,3), - p._Char('y',4,2,1), - p._Char('o',5,2,2), - p._Char('u',6,2,3), - ] + expected = [ + p._Char('h', 1, 1, 1), + p._Char('i', 2, 1, 2), + p._Char('\n', 3, 1, 3), + p._Char('y', 4, 2, 1), + p._Char('o', 5, 2, 2), + p._Char('u', 6, 2, 3), + ] tokens = list(p._annotate_debug('hi\nyou')) self.assertEqual(expected, tokens) def test_nat(self): expected = 42 - txt = str(expected) + txt = str(expected) self.assertParse(p._nat, expected, txt) def test_span(self): - expected = (8,12) - txt = '8..12' + expected = (8, 12) + txt = '8..12' self.assertParse(p._span, expected, txt) def test_gorn(self): - expected = p.GornAddress([0,1,5,3]) - txt = ','.join(map(str,expected.parts)) + expected = p.GornAddress([0, 1, 5, 3]) + txt = ','.join(str(x) for x in expected.parts) self.assertParse(p._gorn, expected, txt) def test_span_list(self): - expected = [(8,12),(9,3),(10,39)] - txt = '8..12;9..3;10..39' + expected = [(8, 12), (9, 3), (10, 39)] + txt = '8..12;9..3;10..39' self.assertParse(p._SpanList, expected, txt) def test_text(self): expected = 'federal thrift\n\nregulators ordered it to suspend \n\n####\n\ndividend payments on its two classes of preferred stock ' - txt = ex_txt + txt = ex_txt self.assertParse(p._RawText, expected, txt) def test_selection(self): - expected = p.Selection(span=[(36,139)], - gorn=[p.GornAddress([0,1,1]),p.GornAddress([2,1])], + expected = p.Selection(span=[(36, 139)], + gorn=[p.GornAddress([0, 1, 1]), + p.GornAddress([2, 1])], text='federal thrift regulators ordered it to suspend dividend payments on its two classes of preferred stock') - txt = ex_selection + txt = ex_selection self.assertParse(p._selection, expected, txt) def test_attribution_core(self): expected = ('Wr', 'Comm', 'Null', 'Null') - txt = "Wr, Comm, Null, Null" + txt = "Wr, Comm, Null, Null" self.assertParse(p._attributionCoreFeatures, expected, txt) def test_attribution(self): expected = p.Attribution('Ot', 'Comm', 'Null', 'Null') - txt = ex_attribution1 + txt = ex_attribution1 self.assertParse(p._attributionFeatures, expected, txt) def test_attribution_sel(self): - expected_sel = p.Selection(span=[(9,35)], - gorn=[p.GornAddress([0,0]), - p.GornAddress([0,1,0]), - p.GornAddress([0,1,2]), - p.GornAddress([0,2])], + expected_sel = p.Selection(span=[(9, 35)], + gorn=[p.GornAddress([0, 0]), + p.GornAddress([0, 1, 0]), + p.GornAddress([0, 1, 2]), + p.GornAddress([0, 2])], text='CenTrust Savings Bank said') expected = p.Attribution('Ot', 'Comm', 'Null', 'Null', expected_sel) - txt = ex_attribution2 + txt = ex_attribution2 
self.assertParse(p._attributionFeatures, expected, txt) def test_semclass(self): expected = 'Chosen alternative' - txt = expected + txt = expected self.assertParse(p._SemanticClassWord, expected, txt) - expected1 = p.SemClass(['Expansion', 'Alternative', 'Chosen alternative']) - expected = expected1 - txt = 'Expansion.Alternative.Chosen alternative' + expected1 = p.SemClass(['Expansion', 'Alternative', + 'Chosen alternative']) + expected = expected1 + txt = 'Expansion.Alternative.Chosen alternative' self.assertParse(p._SemanticClass1, expected, txt) - expected = (expected1, None) + expected = (expected1, None) self.assertParse(p._semanticClass, expected, txt) expected2 = p.SemClass(['Contingency', 'Cause', 'Result']) - expected = (expected1, expected2) - txt = 'Expansion.Alternative.Chosen alternative, Contingency.Cause.Result' + expected = (expected1, expected2) + txt = 'Expansion.Alternative.Chosen alternative, Contingency.Cause.Result' self.assertParse(p._semanticClass, expected, txt) def test_connective(self): - expected = p.Connective('also', p.SemClass(['Expansion','Conjunction'])) - txt = 'also, Expansion.Conjunction' + expected = p.Connective('also', p.SemClass(['Expansion', + 'Conjunction'])) + txt = 'also, Expansion.Conjunction' self.assertParse(p._conn1SemanticClass, expected, txt) self.assertParse(p._conn2SemanticClass, expected, txt) def test_sup(self): - expected_sel = p.Selection(span=[(1730,1799)], - gorn=[p.GornAddress([11,2,3])], + expected_sel = p.Selection(span=[(1730, 1799)], + gorn=[p.GornAddress([11, 2, 3])], text='blop blop split shares') expected = p.Sup(expected_sel) - txt = ex_sup1 + txt = ex_sup1 self.assertParse(p._sup('sup1'), expected, txt) def test_implicit_features_1(self): expected_attr = p.Attribution('Wr', 'Comm', 'Null', 'Null') - expected_conn = p.Connective('also', p.SemClass(['Expansion','Conjunction'])) - expected = p.ImplicitRelationFeatures(expected_attr, expected_conn, None) - txt = ex_implicit_attribution + expected_conn = p.Connective('also', p.SemClass(['Expansion', + 'Conjunction'])) + expected = p.ImplicitRelationFeatures(expected_attr, expected_conn, + None) + txt = ex_implicit_attribution self.assertParse(p._implicitRelationFeatures, expected, txt) def test_implicit_features_2(self): expected_conn1 = p.Connective('in particular', - p.SemClass(['Expansion','Restatement','Specification'])) + p.SemClass(['Expansion', + 'Restatement', + 'Specification'])) expected_conn2 = p.Connective('because', - p.SemClass(['Contingency','Cause','Reason'])) - expected_attr = p.Attribution('Wr', 'Comm', 'Null', 'Null') - expected = p.ImplicitRelationFeatures(expected_attr, expected_conn1, expected_conn2) - txt = ex_implicit_features + p.SemClass(['Contingency', + 'Cause', + 'Reason'])) + expected_attr = p.Attribution('Wr', 'Comm', 'Null', 'Null') + expected = p.ImplicitRelationFeatures(expected_attr, expected_conn1, + expected_conn2) + txt = ex_implicit_features self.assertParse(p._implicitRelationFeatures, expected, txt) def test_frame(self): expected = [ex_frame] - split = p.split_relations(ex_frame) + split = p.split_relations(ex_frame) self.assertEqual(expected, split) def test(self): @@ -213,9 +225,10 @@ def test(self): xs = p.parse(path) self.assertNotEquals(0, len(xs)) + class PdtbXmlTest(unittest.TestCase): def dump(self, elem): - indent_xml(elem) # ugh, imperative + indent_xml(elem) # ugh, imperative print("", file=sys.stderr) print(ET.tostring(elem, encoding='utf-8'), file=sys.stderr) @@ -225,13 +238,14 @@ def test_gorn(self): 
self.assertEqual([itm], x._read_GornAddressList(xml)) def test_span_list(self): - itm = [(4,3), (6,7), (9,9)] + itm = [(4, 3), (6, 7), (9, 9)] xml = x._SpanList_xml(itm) self.assertEqual(itm, x._read_SpanList(xml)) def test_selection(self): - itm = p.Selection(span=[(36,139)], - gorn=[p.GornAddress([0,1,1]),p.GornAddress([2,1])], + itm = p.Selection(span=[(36, 139)], + gorn=[p.GornAddress([0, 1, 1]), + p.GornAddress([2, 1])], text='federal thrift regulators ordered it to suspend dividend payments on its two classes of preferred stock') xml = x._Selection_xml(itm) self.assertEqual(itm, x._read_Selection(xml)) @@ -241,24 +255,24 @@ def test_attribution(self): xml = x._Attribution_xml(itm) self.assertEqual(itm, x._read_Attribution(xml)) - itm_sel = p.Selection(span=[(9,35)], - gorn=[p.GornAddress([0,0]), - p.GornAddress([0,1,0]), - p.GornAddress([0,1,2]), - p.GornAddress([0,2])], - text='CenTrust Savings Bank said') + itm_sel = p.Selection(span=[(9, 35)], + gorn=[p.GornAddress([0, 0]), + p.GornAddress([0, 1, 0]), + p.GornAddress([0, 1, 2]), + p.GornAddress([0, 2])], + text='CenTrust Savings Bank said') itm = p.Attribution('Ot', 'Comm', 'Null', 'Null', itm_sel) xml = x._Attribution_xml(itm) self.assertEqual(itm, x._read_Attribution(xml)) def test_connective(self): itm = p.Connective('also', - p.SemClass(['Expansion','Conjunction'])) + p.SemClass(['Expansion', 'Conjunction'])) xml = x._Connective_xml(itm) self.assertEqual(itm, x._read_Connective(xml)) itm = p.Connective('also', - p.SemClass(['Expansion','Conjunction']), + p.SemClass(['Expansion', 'Conjunction']), p.SemClass(['Contingency', 'Cause', 'Result'])) xml = x._Connective_xml(itm) self.assertEqual(itm, x._read_Connective(xml)) diff --git a/educe/pdtb/util/cmd/extract.py b/educe/pdtb/util/cmd/extract.py index ff6f41c..8070a83 100644 --- a/educe/pdtb/util/cmd/extract.py +++ b/educe/pdtb/util/cmd/extract.py @@ -3,6 +3,9 @@ """ Extract features + +2017-01-27 this code is broken ; it relies on stac.keys.KeyGroupWriter +which was deprecated and removed a while back. """ import codecs diff --git a/educe/pdtb/util/cmd/tmp.py b/educe/pdtb/util/cmd/tmp.py index bea152f..d24ca9b 100644 --- a/educe/pdtb/util/cmd/tmp.py +++ b/educe/pdtb/util/cmd/tmp.py @@ -7,11 +7,8 @@ from __future__ import print_function import collections -import sys -from ..args import\ - add_usual_input_args,\ - read_corpus +from ..args import add_usual_input_args, read_corpus NAME = 'tmp' @@ -26,12 +23,15 @@ def config_argparser(parser): add_usual_input_args(parser) parser.set_defaults(func=main) + def sentence_nums(arg): return [x.parts[0] for x in arg.gorn] + def is_multisentential(args): return len(frozenset(sentence_nums(args))) > 1 + def main(args): """ Subcommand main. 
@@ -50,9 +50,9 @@ def main(args): print("--------------------" * 3) print() for rel in corpus[k]: - #if (len(rel.arg1.span) > 2 or len(rel.arg2.span) > 2): - # print(unicode(rel).encode('utf-8')) - # print() + # if (len(rel.arg1.span) > 2 or len(rel.arg2.span) > 2): + # print(unicode(rel).encode('utf-8')) + # print() if is_multisentential(rel.arg1): counts[k] += 1 if is_multisentential(rel.arg2): @@ -64,5 +64,6 @@ def main(args): total += counts[k] total_args += num_args[k] for k in sorted(corpus): - print("%s: %d/%d multisentential args" % (k.doc, counts[k], num_args[k])) + print("%s: %d/%d multisentential args" % ( + k.doc, counts[k], num_args[k])) print("altogether: %d/%d multisentential args" % (total, total_args)) diff --git a/educe/pdtb/util/cmd/xml_.py b/educe/pdtb/util/cmd/xml_.py index b170a26..2a7ee6e 100644 --- a/educe/pdtb/util/cmd/xml_.py +++ b/educe/pdtb/util/cmd/xml_.py @@ -8,11 +8,9 @@ from __future__ import print_function from educe.pdtb import pdtbx -from ..args import\ - add_usual_input_args, add_usual_output_args,\ - get_output_dir, mk_output_path,\ - announce_output_dir,\ - read_corpus +from ..args import (add_usual_input_args, add_usual_output_args, + get_output_dir, mk_output_path, announce_output_dir, + read_corpus) NAME = 'xml' @@ -41,6 +39,6 @@ def main(args): for k in sorted(corpus): opath = mk_output_path(output_dir, k) + '.pdtbx' pdtbx.write_pdtbx_file(opath, corpus[k]) - #readback = pdtbx.read_pdtbx_file(opath) - #assert(corpus[k] == readback) + # readback = pdtbx.read_pdtbx_file(opath) + # assert(corpus[k] == readback) announce_output_dir(output_dir) From bd2a3e74ee17b80a4e49432462abe0cf08c04b36 Mon Sep 17 00:00:00 2001 From: moreymat Date: Fri, 27 Jan 2017 16:05:49 +0100 Subject: [PATCH 05/44] MAINT educe.stac.{edit,oneoff} minor style --- educe/stac/edit/cmd/insert.py | 10 +++------- educe/stac/edit/cmd/nudge.py | 2 +- educe/stac/edit/cmd/rewrite.py | 6 ++---- educe/stac/oneoff/weave.py | 4 ++-- 4 files changed, 8 insertions(+), 14 deletions(-) diff --git a/educe/stac/edit/cmd/insert.py b/educe/stac/edit/cmd/insert.py index f984531..eabae7b 100644 --- a/educe/stac/edit/cmd/insert.py +++ b/educe/stac/edit/cmd/insert.py @@ -11,14 +11,10 @@ import educe.stac from educe.stac.util.annotate import show_diff -from educe.stac.util.args import\ - (add_usual_input_args, - add_usual_output_args, - announce_output_dir, - get_output_dir) +from educe.stac.util.args import (add_usual_input_args, add_usual_output_args, + announce_output_dir, get_output_dir) from educe.stac.util.output import save_document -from educe.stac.util.doc import\ - compute_renames, move_portion +from educe.stac.util.doc import compute_renames, move_portion from .move import is_requested diff --git a/educe/stac/edit/cmd/nudge.py b/educe/stac/edit/cmd/nudge.py index 4381676..f905277 100644 --- a/educe/stac/edit/cmd/nudge.py +++ b/educe/stac/edit/cmd/nudge.py @@ -103,7 +103,7 @@ def _screen_args(args): % (args.nudge_start, args.nudge_end)) if not args.allow_shove and (args.annotator or args.stage): sys.exit("Use --allow-shove if you really mean to limit " - + "--stage or --annotator") + "--stage or --annotator") if args.stage: if args.stage != 'unannotated' and not args.annotator: sys.exit("--annotator is required unless --stage is unannotated") diff --git a/educe/stac/edit/cmd/rewrite.py b/educe/stac/edit/cmd/rewrite.py index 37bc9d7..be00d25 100644 --- a/educe/stac/edit/cmd/rewrite.py +++ b/educe/stac/edit/cmd/rewrite.py @@ -8,10 +8,8 @@ import copy -from educe.stac.util.args import\ - 
add_usual_input_args,\ - read_corpus,\ - get_output_dir, announce_output_dir +from educe.stac.util.args import (add_usual_input_args, read_corpus, + get_output_dir, announce_output_dir) from educe.stac.util.output import save_document from educe.stac.context import sorted_first_widest diff --git a/educe/stac/oneoff/weave.py b/educe/stac/oneoff/weave.py index 5125f2e..870b372 100644 --- a/educe/stac/oneoff/weave.py +++ b/educe/stac/oneoff/weave.py @@ -782,8 +782,8 @@ def shift_dialogues(doc_src, doc_res, updates, gen): gturn_idc_end = np.array( [i - 1 for i in gturn_idc[1:]] + [len(turns_src) - 1]) # ... and finally - gturn_src_tid_beg = turns_src_tid[gturn_idc_beg] - gturn_src_tid_end = turns_src_tid[gturn_idc_end] + # gturn_src_tid_beg = turns_src_tid[gturn_idc_beg] + # gturn_src_tid_end = turns_src_tid[gturn_idc_end] # 2. get the identifier of the first and last turn of each dialogue # in _res: these turns and those in between must end up in the same From 3a1e3380b6606483b3defce4f2104b9aeff3891c Mon Sep 17 00:00:00 2001 From: moreymat Date: Fri, 27 Jan 2017 16:18:21 +0100 Subject: [PATCH 06/44] MAINT educe.stac.sanity pylint --- educe/stac/sanity/checks/glozz.py | 4 ++-- educe/stac/sanity/checks/graph.py | 4 ++-- educe/stac/sanity/test_sanity.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/educe/stac/sanity/checks/glozz.py b/educe/stac/sanity/checks/glozz.py index 445037e..c5ed6f5 100644 --- a/educe/stac/sanity/checks/glozz.py +++ b/educe/stac/sanity/checks/glozz.py @@ -190,8 +190,8 @@ def duplicate_annotations(inputs, k): annos = defaultdict(list) for anno in doc.annotations(): annos[anno.local_id()].append(anno) - return [DuplicateItem(doc, contexts, k, v) - for k, v in annos.items() if len(v) > 1] + return [DuplicateItem(doc, contexts, ek, ev) + for ek, ev in annos.items() if len(ev) > 1] # ---------------------------------------------------------------------- # overlaps diff --git a/educe/stac/sanity/checks/graph.py b/educe/stac/sanity/checks/graph.py index 188d368..ba0581c 100644 --- a/educe/stac/sanity/checks/graph.py +++ b/educe/stac/sanity/checks/graph.py @@ -111,8 +111,8 @@ def search_graph_cdu_overlap(inputs, k, gra): for mem in gra.cdu_members(cdu): edu_anno = gra.annotation(mem) containers[edu_anno].append(cdu_anno) - return [CduOverlapItem(doc, contexts, k, v) - for k, v in containers.items() if len(v) > 1] + return [CduOverlapItem(doc, contexts, ek, ev) + for ek, ev in containers.items() if len(ev) > 1] def is_arrow_inversion(gra, _, rel): diff --git a/educe/stac/sanity/test_sanity.py b/educe/stac/sanity/test_sanity.py index f5c6765..c3d08b9 100644 --- a/educe/stac/sanity/test_sanity.py +++ b/educe/stac/sanity/test_sanity.py @@ -142,7 +142,7 @@ def get_id(x): return g.mirror(ids[x.local_id()]) mark = self.edu1_2.local_id() - self.assertEqual(list(map(get_id, [c1, c2])), + self.assertEqual([get_id(y) for y in [c1, c2]], g.containing_cdu_chain(ids[mark])) def test_enclosed(self): From f9976004282f396b96d429630ce2f993a1f05d6a Mon Sep 17 00:00:00 2001 From: moreymat Date: Fri, 27 Jan 2017 16:20:42 +0100 Subject: [PATCH 07/44] MAINT educe.stac.util minor style --- educe/stac/util/cmd/TEMPLATE.py | 3 +-- educe/stac/util/cmd/count_rfc.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/educe/stac/util/cmd/TEMPLATE.py b/educe/stac/util/cmd/TEMPLATE.py index 78236ab..23f57c9 100644 --- a/educe/stac/util/cmd/TEMPLATE.py +++ b/educe/stac/util/cmd/TEMPLATE.py @@ -5,8 +5,7 @@ Put subcommand help text here """ -from ..args import\ - 
add_usual_input_args, add_usual_output_args +from ..args import add_usual_input_args, add_usual_output_args NAME = 'insert-name-here' diff --git a/educe/stac/util/cmd/count_rfc.py b/educe/stac/util/cmd/count_rfc.py index 4181bfe..1b0959f 100644 --- a/educe/stac/util/cmd/count_rfc.py +++ b/educe/stac/util/cmd/count_rfc.py @@ -129,7 +129,7 @@ def display_power(res): rfc_power = (100 * avg_frontier_size) / nb_edus row.append(rfc_power) tres.append(row) - print(tabulate(tres, headers=col_names, floatfmt='.1f')+'\n') + print(tabulate(tres, headers=col_names, floatfmt='.1f') + '\n') def config_argparser(parser): From abeefef8a5ef96e97c3181f867b773c71f8aa4d3 Mon Sep 17 00:00:00 2001 From: moreymat Date: Fri, 27 Jan 2017 16:37:49 +0100 Subject: [PATCH 08/44] MAINT rename local csv modules to {educe,stac}_csv_format --- educe/learning/{csv.py => educe_csv_format.py} | 0 educe/stac/util/{csv.py => stac_csv_format.py} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename educe/learning/{csv.py => educe_csv_format.py} (100%) rename educe/stac/util/{csv.py => stac_csv_format.py} (100%) diff --git a/educe/learning/csv.py b/educe/learning/educe_csv_format.py similarity index 100% rename from educe/learning/csv.py rename to educe/learning/educe_csv_format.py diff --git a/educe/stac/util/csv.py b/educe/stac/util/stac_csv_format.py similarity index 100% rename from educe/stac/util/csv.py rename to educe/stac/util/stac_csv_format.py From a361c6bb85e09bd779fa566eb81f3179d90490e8 Mon Sep 17 00:00:00 2001 From: moreymat Date: Fri, 27 Jan 2017 16:38:37 +0100 Subject: [PATCH 09/44] MAINT educe.stac more pylint --- educe/stac/lexicon/markers.py | 57 +++++++++++++++--------------- educe/stac/lexicon/pdtb_markers.py | 8 ++--- educe/stac/lexicon/wordclass.py | 3 +- educe/stac/util/prettifyxml.py | 9 ++--- 4 files changed, 40 insertions(+), 37 deletions(-) diff --git a/educe/stac/lexicon/markers.py b/educe/stac/lexicon/markers.py index 40aa7c2..c50ab3f 100644 --- a/educe/stac/lexicon/markers.py +++ b/educe/stac/lexicon/markers.py @@ -12,10 +12,10 @@ _table = { - "1":{"marker":"connecteur","form":"forme"}, - "2":{"marker":"connective","form":"form"} - } - + "1": {"marker": "connecteur", "form": "forme"}, + "2": {"marker": "connective", "form": "form"} +} + _stopwords = set(u"à et ou pour en".split()) @@ -27,17 +27,19 @@ class Marker: version 1 has type (coord/subord) version 2 has grammatical host and lemma """ - def __init__(self,elmt,version="2",stop=_stopwords): - self._forms = [x.text.strip() for x in elmt.findall(".//%s"%_table[version]["form"])] - self.__dict__.update(elmt.attrib) - # - if version == "2": - self.relations = [x.attrib["relation"] for x in elmt.findall(".//use")] - else: - self.relations = [x.strip() for x in self.relations.split(",")] - self.lemma = self.forms[0] - self.host = None - + def __init__(self, elmt, version="2", stop=_stopwords): + self._forms = [x.text.strip() + for x in elmt.findall(".//%s" % _table[version]["form"])] + self.__dict__.update(elmt.attrib) + # + if version == "2": + self.relations = [x.attrib["relation"] + for x in elmt.findall(".//use")] + else: + self.relations = [x.strip() for x in self.relations.split(",")] + self.lemma = self.forms[0] + self.host = None + def get_forms(self): return self._forms @@ -47,33 +49,32 @@ def get_lemma(self): def get_relations(self): return self.relations + class LexConn: - def __init__(self,infile,version="2",stop=_stopwords): + def __init__(self, infile, version="2", stop=_stopwords): """read lexconn file, version is 1 or 2 """ lex 
= ET.parse(infile) - markers = [Marker(x,version=version) for x in lex.findall(".//%s"%_table[version]["marker"])] + markers = [Marker(x, version=version) + for x in lex.findall(".//%s" % _table[version]["marker"])] markers = [x for x in markers if x.get_lemma() not in stop] - self._markers = dict([(x.id,x) for x in markers]) + self._markers = dict([(x.id, x) for x in markers]) def __iter__(self): return iter(self._markers.values()) - def get_by_id(self,id): - return self._markers.get(id,None) + def get_by_id(self, id): + return self._markers.get(id, None) - def get_by_form(self,form): + def get_by_form(self, form): return [x for x in self._markers.values() if form in x.get_forms()] - - def get_by_lemma(self,lemma): - return [x for x in self._markers.values() if lemma==x.get_lemma()] + def get_by_lemma(self, lemma): + return [x for x in self._markers.values() if lemma == x.get_lemma()] # tests -if __name__=="__main__": - import sys - +if __name__ == "__main__": infile = sys.argv[1] - lex = LexConn(infile,version=sys.argv[2]) + lex = LexConn(infile, version=sys.argv[2]) diff --git a/educe/stac/lexicon/pdtb_markers.py b/educe/stac/lexicon/pdtb_markers.py index e925ed0..4faaa2f 100755 --- a/educe/stac/lexicon/pdtb_markers.py +++ b/educe/stac/lexicon/pdtb_markers.py @@ -141,13 +141,13 @@ def read_lexicon(filename): Parameters ---------- - filename: string - Path to the lexicon + filename : string + Path to the lexicon. Returns ------- - relations: dict(string, frozenset(Marker)) - Relations and their signalling discourse markers + relations : dict(string, frozenset(Marker)) + Relations and their signalling discourse markers. """ rel2markers = defaultdict(list) # compute the inverse mapping; marker2rels -> rel2markers diff --git a/educe/stac/lexicon/wordclass.py b/educe/stac/lexicon/wordclass.py index 431fb5e..51c91d1 100644 --- a/educe/stac/lexicon/wordclass.py +++ b/educe/stac/lexicon/wordclass.py @@ -40,7 +40,8 @@ class LexEntry(namedtuple("LexEntry", def __new__(cls, word, lex_class, pos, subclass): pos = pos if pos != '??' 
else None subclass = subclass or None - return super(LexEntry, cls).__new__(cls, word, lex_class, pos, subclass) + return super(LexEntry, cls).__new__(cls, word, lex_class, pos, + subclass) @classmethod def read_entry(cls, line): diff --git a/educe/stac/util/prettifyxml.py b/educe/stac/util/prettifyxml.py index e9204cd..76f0c75 100644 --- a/educe/stac/util/prettifyxml.py +++ b/educe/stac/util/prettifyxml.py @@ -1,7 +1,8 @@ #!/usr/bin/python # -*- coding: utf-8 -*- -'''Function to "prettify" XML: courtesy of http://www.doughellmann.com/PyMOTW/xml/etree/ElementTree/create.html +'''Function to "prettify" XML: courtesy of +http://www.doughellmann.com/PyMOTW/xml/etree/ElementTree/create.html ''' from __future__ import print_function @@ -18,6 +19,6 @@ def prettify(elem, indent=""): return reparsed.toprettyxml(indent=indent) if __name__ == '__main__': - TREE = ElementTree.parse(sys.argv[1]) - ROOT = TREE.getroot() - print(prettify(ROOT)) + tree = ElementTree.parse(sys.argv[1]) + root = tree.getroot() + print(prettify(root)) From de7089836d55d1f87a15aa29898bc2efcbfd0ace Mon Sep 17 00:00:00 2001 From: moreymat Date: Fri, 27 Jan 2017 16:45:21 +0100 Subject: [PATCH 10/44] FIX+MAINT catch up with renamed module, pylint --- educe/stac/learning/features.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/educe/stac/learning/features.py b/educe/stac/learning/features.py index 7e622c2..621293c 100644 --- a/educe/stac/learning/features.py +++ b/educe/stac/learning/features.py @@ -7,12 +7,10 @@ """ from __future__ import absolute_import, print_function -from collections import defaultdict, namedtuple +from collections import defaultdict, namedtuple, Sequence from functools import wraps -from itertools import chain -import collections import copy -import itertools as itr +import itertools import os import re import sys @@ -31,11 +29,10 @@ edus_in_span, turns_in_span) from educe.stac.corpus import (twin_key) -from educe.learning.csv import tune_for_csv +from educe.learning.educe_csv_format import SparseDictReader, tune_for_csv from educe.learning.util import tuple_feature, underscore import educe.corpus import educe.glozz -import educe.learning.csv as educe_csv import educe.stac import educe.stac.lexicon.pdtb_markers as pdtb_markers import educe.stac.graph as stac_gr @@ -154,7 +151,7 @@ def emoticons(tokens): def is_just_emoticon(tokens): "Return true if a sequence of tokens consists of a single emoticon" - if not isinstance(tokens, collections.Sequence): + if not isinstance(tokens, Sequence): raise TypeError("tokens must form a sequence") return bool(emoticons(tokens)) and len(tokens) == 1 @@ -239,7 +236,7 @@ def has_pdtb_markers(markers, tokens): Given a sequence of tagged tokens, return True if any of the given PDTB markers appears within the tokens """ - if not isinstance(tokens, collections.Sequence): + if not isinstance(tokens, Sequence): raise TypeError("tokens must form a sequence") words = [t.word for t in tokens] return pdtb_markers.Marker.any_appears_in(markers, words) @@ -304,7 +301,7 @@ def map_topdown(good, prunable, trees): """ Do topdown search on all these trees, concatenate results. 
""" - return list(chain.from_iterable( + return list(itertools.chain.from_iterable( tree.topdown(good, prunable) for tree in trees if isinstance(tree, SearchableTree))) @@ -1278,7 +1275,7 @@ def _mk_high_level_dialogues(current): for dia in dialogues: d_edus = edus_in_dialogues[dia] d_relations = {} - for pair in itr.product([FakeRootEDU] + d_edus, d_edus): + for pair in itertools.product([FakeRootEDU] + d_edus, d_edus): rel = relations.get(_id_pair(pair)) if rel is not None: d_relations[pair] = rel @@ -1359,7 +1356,7 @@ def _read_inquirer_lexicon(args): """ inq_txt_file = os.path.join(args.resources, INQUIRER_BASENAME) with open(inq_txt_file) as cin: - creader = educe_csv.SparseDictReader(cin, delimiter='\t') + creader = SparseDictReader(cin, delimiter='\t') words = defaultdict(list) for row in creader: for k in row: From f781c815dc2435660bdd33faf3964b007dd92ec2 Mon Sep 17 00:00:00 2001 From: moreymat Date: Sat, 28 Jan 2017 09:21:17 +0100 Subject: [PATCH 11/44] MAINT educe/*.py pylint, minor fixes for style --- educe/annotation.py | 19 +++-- educe/corpus.py | 43 ++++++----- educe/glozz.py | 3 +- educe/graph.py | 67 +++++++++-------- educe/internalutil.py | 11 ++- educe/tests.py | 162 ++++++++++++++++++++++-------------------- 6 files changed, 165 insertions(+), 140 deletions(-) diff --git a/educe/annotation.py b/educe/annotation.py index b5f429c..0262e2c 100644 --- a/educe/annotation.py +++ b/educe/annotation.py @@ -200,9 +200,14 @@ def __repr__(self): # pylint: disable=no-self-use class Standoff(object): - """ - A standoff object ultimately points to some piece of text. + """A standoff object ultimately points to some piece of text. + The pointing is not necessarily direct though + + Parameters + ---------- + origin : educe.corpus.FileId + FileId of the document supporting this standoff. """ def __init__(self, origin=None): self.origin = origin @@ -212,7 +217,8 @@ def _members(self): Any annotations contained within this annotation. Must return None if is a terminal annotation (not the same - meaning as returning the empty list) + meaning as returning the empty list). + Non-terminal annotations must override this. """ return None @@ -223,12 +229,11 @@ def _terminals(self, seen=None): terminals """ my_members = self._members() - seen = seen or [] if my_members is None: return [self] - else: - return chain.from_iterable([m._terminals(seen + my_members) - for m in my_members if m not in seen]) + seen = seen or [] + return chain.from_iterable([m._terminals(seen + my_members) + for m in my_members if m not in seen]) def text_span(self): """ diff --git a/educe/corpus.py b/educe/corpus.py index ddde1a0..23d88f5 100644 --- a/educe/corpus.py +++ b/educe/corpus.py @@ -17,7 +17,6 @@ # the above. Give us a mapping from FileId to filepaths and we # do the rest. 
-import sys class FileId: """ @@ -49,14 +48,14 @@ class FileId: :type annotator: string """ def __init__(self, doc, subdoc, stage, annotator): - self.doc=doc - self.subdoc=subdoc - self.stage=stage - self.annotator=annotator + self.doc = doc + self.subdoc = subdoc + self.stage = stage + self.annotator = annotator def __str__(self): - return "%s [%s] %s %s" % (self.doc, self.subdoc, self.stage, self.annotator) - + return "%s [%s] %s %s" % (self.doc, self.subdoc, self.stage, + self.annotator) def _tuple(self): """ @@ -95,13 +94,14 @@ def mk_global_id(self, local_id): parts = [self.doc, self.subdoc, local_id] return "_".join(p for p in parts if p is not None) + class Reader: """ `Reader` provides little more than dictionaries from `FileId` to data. :param rootdir: the top directory of the corpus - :type rootdir: string + :type rootdir: str A potentially useful pattern to apply here is to take a slice of these dictionaries for processing. For example, you might not want @@ -110,24 +110,24 @@ class Reader: .. code-block:: python - reader = Reader(corpus_dir) - files = reader.files() - subfiles = { k:v in files.items() if k.annotator in [ 'Bob', 'Alice' ] } - corpus = reader.slurp(subfiles) + reader = Reader(corpus_dir) + files = reader.files() + subfiles = {k: v in files.items() if k.annotator in ['Bob', 'Alice']} + corpus = reader.slurp(subfiles) Alternatively, having read in the entire corpus, you might be doing processing on various slices of it at a time .. code-block:: python - corpus = reader.slurp() - subcorpus = { k:v in corpus.items() if k.doc == 'pilot14' } + corpus = reader.slurp() + subcorpus = {k: v in corpus.items() if k.doc == 'pilot14'} This is an abstract class; you should use the version from a data-set, eg. `educe.stac.Reader` instead """ - def __init__(self, dir): - self.rootdir=dir + def __init__(self, root): + self.rootdir = root def files(self): """ @@ -152,10 +152,8 @@ def slurp(self, cfiles=None, verbose=False): :param verbose: print what we're reading to stderr :type verbose: bool """ - if cfiles is None: - subcorpus=self.files() - else: - subcorpus=cfiles + subcorpus = (cfiles if cfiles is not None + else self.files()) return self.slurp_subcorpus(subcorpus, verbose) def slurp_subcorpus(self, cfiles, verbose=False): @@ -168,7 +166,6 @@ def filter(self, d, pred): """ Convenience function equivalent to :: - { k:v for k,v in d.items() if pred(k) } + { k: v for k, v in d.items() if pred(k) } """ - return dict([(k,v) for k,v in d.items() if pred(k)]) - + return dict([(k, v) for k, v in d.items() if pred(k)]) diff --git a/educe/glozz.py b/educe/glozz.py index 524f806..72c6e68 100644 --- a/educe/glozz.py +++ b/educe/glozz.py @@ -105,7 +105,8 @@ def glozz_annotation_to_xml(self, tag='annotation', elif tag == 'schema': span_elm = glozz_schema_to_span_xml(self) else: - raise Exception("Don't know how to emit XML for non unit/relation annotations (%s)" % tag) + raise ValueError("Don't know how to emit XML for non unit/relation " + "annotations (%s)" % tag) elm.extend([meta_elm, char_elm, span_elm]) return elm diff --git a/educe/graph.py b/educe/graph.py index 1b849b5..90b40df 100644 --- a/educe/graph.py +++ b/educe/graph.py @@ -102,10 +102,11 @@ import pydot import pygraph.classes.hypergraph as gr -import pygraph.classes.digraph as dgr +import pygraph.classes.digraph as dgr from pygraph.algorithms import accessibility -# pylint: disable=too-few-public-methods, star-args + +# pylint: disable=too-few-public-methods class DuplicateIdException(Exception): '''Condition that arises in 
inconsistent corpora''' @@ -113,6 +114,7 @@ def __init__(self, duplicate): self.duplicate = duplicate Exception.__init__(self, "Duplicate node id: %s" % duplicate) + class AttrsMixin(): """ Attributes common to both the hypergraph and directed graph @@ -194,6 +196,7 @@ def edgeform(self, x): else: return self.mirror(x) + class Graph(gr.hypergraph, AttrsMixin): """ Hypergraph representation of discourse structure. @@ -274,7 +277,7 @@ def from_doc(cls, corpus, doc_key, nodes = [] edges = [] - edus = [x for x in doc.units if x.local_id() in included and pred(x)] + edus = [x for x in doc.units if x.local_id() in included and pred(x)] rels = [x for x in doc.relations if pred(x)] cdus = [s for s in doc.schemas if pred(s)] @@ -334,7 +337,7 @@ def copy(self, nodeset=None): else: nodes_wanted = set(nodeset) - cdus = [ x for x in nodes_wanted if self.is_cdu(x) ] + cdus = [x for x in nodes_wanted if self.is_cdu(x)] for x in cdus: nodes_wanted.update(self.cdu_members(x, deep=True)) @@ -352,7 +355,7 @@ def is_wanted_edge(e): for e in edges_remaining: if is_wanted_edge(e): edges_wanted.add(e) - nodes_wanted.add(self.mirror(e)) # obligatory node mirror + nodes_wanted.add(self.mirror(e)) # obligatory node mirror edges_remaining.remove(e) keep_growing = True @@ -360,15 +363,15 @@ def is_wanted_edge(e): if n in nodes_wanted: g.add_node(n) for kv in self.node_attributes(n): - g.add_node_attribute(n,kv) + g.add_node_attribute(n, kv) for e in self.hyperedges(): if e in edges_wanted: g.add_hyperedge(e) for kv in self.edge_attributes(e): - g.add_edge_attribute(e,kv) + g.add_edge_attribute(e, kv) for l in self.links(e): - g.link(l,e) + g.link(l, e) return g @@ -424,7 +427,8 @@ def _attrs(self, x): elif self.has_node(x): return self.node_attributes_dict(x) else: - raise Exception('Tried to get attributes of non-existing object ' + str(x)) + raise Exception('Tried to get attributes of non-existing' + ' object ' + str(x)) def relations(self): """ @@ -562,7 +566,7 @@ def _schema_node(self, anno): return self._mk_node(anno, 'CDU', mirrored=True) def _rel_edge(self, anno): - members = [ anno.span.t1, anno.span.t2 ] + members = [anno.span.t1, anno.span.t2] return self._mk_edge(anno, 'rel', members, mirrored=True) def _schema_edge(self, anno): @@ -580,19 +584,21 @@ def _repr_svg_(self): """Ipython magic: show SVG representation of the graph""" dot_string = self._repr_dot_() try: - process = subprocess.Popen(['dot', '-Tsvg'], stdin=subprocess.PIPE, - stdout=subprocess.PIPE, stderr=subprocess.PIPE) + process = subprocess.Popen( + ['dot', '-Tsvg'], stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE) except OSError: raise Exception('Cannot find the dot binary from Graphviz package') out, err = process.communicate(dot_string) if err: - raise Exception('Cannot create svg representation by running dot from string\n:%s' % dot_string) + raise Exception('Cannot create svg representation by running' + ' dot from string\n:%s' % dot_string) return out + # --------------------------------------------------------------------- # visualisation # --------------------------------------------------------------------- - class DotGraph(pydot.Dot): """ A dot representation of this graph for visualisation. 
@@ -622,12 +628,13 @@ def _complex_rel_attrs(self, anno): Return attributes for (midpoint, to midpoint, from midpoint) """ - midpoint_attrs =\ - {'label': self._rel_label(anno), - 'style': 'dotted', - 'fontcolor': 'blue'} - attrs1 = {'arrowhead' : 'tee', - 'arrowsize' : '0.5'} + midpoint_attrs = { + 'label': self._rel_label(anno), + 'style': 'dotted', + 'fontcolor': 'blue' + } + attrs1 = {'arrowhead': 'tee', + 'arrowsize': '0.5'} attrs2 = {} return (midpoint_attrs, attrs1, attrs2) @@ -698,7 +705,7 @@ def __point(self, logical_target, key): elif self.core.has_edge(proxy_target): proxy_target = self.core.mirror(proxy_target) proxy_target = self._dot_id(proxy_target) - res = (proxy_target, {key:dot_target}) + res = (proxy_target, {key: dot_target}) return res @@ -848,7 +855,7 @@ def __init__(self, anno_graph): self.complex_rels.add(e2) # CDUs which overlap other CDUs - #self.complex_cdus = self.core.cdus() + # self.complex_cdus = self.core.cdus() self.complex_cdus = set() for e in self.core.cdus(): members = self.core.cdu_members(e) @@ -875,8 +882,8 @@ def __init__(self, anno_graph): # Add nodes that have some sort of error condition or another for edge in (self.core.relations() | self.core.cdus()): for node in self.core.links(edge): - if not (self.core.is_edu(node) or\ - self.core.is_relation(node) or\ + if not (self.core.is_edu(node) or + self.core.is_relation(node) or self.core.is_cdu(node)): self._add_edu(node) @@ -894,10 +901,10 @@ def __init__(self, anno_graph): else: self._add_simple_cdu(edge) + # --------------------------------------------------------------------- # enclosure graphs # --------------------------------------------------------------------- - class EnclosureGraph(dgr.digraph, AttrsMixin): """ Caching mechanism for span enclosure. 
Given an iterable of Annotation, @@ -1000,9 +1007,9 @@ def _mk_node_id(self, anno): def _mk_node(self, anno): # a node is mirrored if there is a also an edge # corresponding to the same object - node_id = self._mk_node_id(anno) + node_id = self._mk_node_id(anno) attrs = {'type': anno.type, - 'annotation' : anno} + 'annotation': anno} return (node_id, attrs) def _add_edge(self, anno1, anno2): @@ -1038,10 +1045,10 @@ def outside(self, annotation): class EnclosureDotGraph(pydot.Dot): def _add_unit(self, node): - anno = self.core.annotation(node) + anno = self.core.annotation(node) label = self._unit_label(anno) - attrs = {'label' : textwrap.fill(label, 30), - 'shape' : 'plaintext'} + attrs = {'label': textwrap.fill(label, 30), + 'shape': 'plaintext'} self.add_node(pydot.Node(node, **attrs)) def _add_edge(self, edge): diff --git a/educe/internalutil.py b/educe/internalutil.py index 5521f1a..d2c6c60 100644 --- a/educe/internalutil.py +++ b/educe/internalutil.py @@ -33,6 +33,7 @@ class EduceXmlException(Exception): def __init__(self, *args, **kw): Exception.__init__(self, *args, **kw) + def on_single_element(root, default, f, name): """ Return @@ -41,17 +42,20 @@ def on_single_element(root, default, f, name): * f(the node) if one element * an exception if more than one """ - nodes=root.findall(name) + nodes = root.findall(name) if len(nodes) == 0: if default is None: - raise EduceXmlException("Expected but did not find any nodes with name %s" % name) + raise EduceXmlException("Expected but did not find any nodes " + "with name %s" % name) else: return default elif len(nodes) > 1: - raise EduceXmlException("Found more than one node with name %s" % name) + raise EduceXmlException("Found more than one node with " + "name %s" % name) else: return f(nodes[0]) + def linebreak_xml(elem): """ Insert a break after each element tag @@ -72,6 +76,7 @@ def linebreak_xml(elem): if not elem.tail or not elem.tail.strip(): elem.tail = i + def indent_xml(elem, level=0): """ From diff --git a/educe/tests.py b/educe/tests.py index 91eac3f..97b20a4 100644 --- a/educe/tests.py +++ b/educe/tests.py @@ -14,7 +14,7 @@ Annotation, Unit, Relation, Schema, Document) import educe.graph as educe -from educe.graph import EnclosureGraph +from educe.graph import EnclosureGraph from educe.util import relative_indices @@ -79,7 +79,6 @@ def test_overlap_empty(self): self.assertOverlap((5, 5), (5, 5), (5, 6), inclusive=True) - class NullAnno(Span, Annotation): def __init__(self, start, end, type="null"): super(NullAnno, self).__init__(start, end) @@ -90,9 +89,9 @@ def local_id(self): return str(self) def __eq__(self, other): - return self.char_start == other.char_start\ - and self.char_end == other.char_end\ - and self.type == other.type + return (self.char_start == other.char_start and + self.char_end == other.char_end and + self.type == other.type) def __ne__(self, other): return not self == other @@ -101,44 +100,45 @@ def __hash__(self): return hash((self.char_start, self.char_end, self.type)) def __repr__(self): - return "%s [%s]" % (super(NullAnno,self).__str__(),self.type) + return "%s [%s]" % (super(NullAnno, self).__str__(), self.type) def __str__(self): return repr(self) + class EnclosureTest(unittest.TestCase): def test_trivial(self): g = EnclosureGraph([]) self.assertEqual(0, len(g.nodes())) def test_singleton(self): - s0 = NullAnno(1,5) + s0 = NullAnno(1, 5) g = EnclosureGraph([s0]) self.assertEqual([s0.local_id()], g.nodes()) self.assertEqual([], g.inside(s0)) self.assertEqual([], g.outside(s0)) def 
test_simple_enclosure(self): - s1_5 = NullAnno(1,5) - s2_3 = NullAnno(2,3) + s1_5 = NullAnno(1, 5) + s2_3 = NullAnno(2, 3) g = EnclosureGraph([s1_5, s2_3]) self.assertEqual([s2_3], g.inside(s1_5)) self.assertEqual([s1_5], g.outside(s2_3)) def test_indirect_enclosure(self): - s1_5 = NullAnno(1,5,'a') - s2_4 = NullAnno(2,4,'b') - s3_4 = NullAnno(3,4,'c') + s1_5 = NullAnno(1, 5, 'a') + s2_4 = NullAnno(2, 4, 'b') + s3_4 = NullAnno(3, 4, 'c') g = EnclosureGraph([s1_5, s2_4, s3_4]) self.assertEqual([s2_4], g.inside(s1_5)) self.assertEqual([s1_5], g.outside(s2_4)) self.assertEqual([s2_4], g.outside(s3_4)) def test_same_span(self): - s1_5 = NullAnno(1,5,'out') - s2_4a = NullAnno(2,4,'a') - s2_4b = NullAnno(2,4,'b') - s3_4 = NullAnno(3,4,'in') + s1_5 = NullAnno(1, 5, 'out') + s2_4a = NullAnno(2, 4, 'a') + s2_4b = NullAnno(2, 4, 'b') + s3_4 = NullAnno(3, 4, 'in') g = EnclosureGraph([s1_5, s2_4a, s2_4b, s3_4]) self.assertEqual([s2_4a, s2_4b], g.inside(s1_5)) self.assertEqual([s3_4], g.inside(s2_4a)) @@ -158,9 +158,9 @@ def test_indirect_enclosure_untyped(self): """ reduce only pays attention to nodes of different type """ - s_1_5 = NullAnno(1,5) - s_2_4 = NullAnno(2,4) - s_3_4 = NullAnno(3,4) + s_1_5 = NullAnno(1, 5) + s_2_4 = NullAnno(2, 4) + s_3_4 = NullAnno(3, 4) g = EnclosureGraph([s_1_5, s_2_4, s_3_4]) self.assertEqual([s_2_4, s_3_4], g.inside(s_1_5)) self.assertEqual([s_1_5], g.outside(s_2_4)) @@ -170,84 +170,93 @@ def test_indirect_enclosure_untyped(self): # --------------------------------------------------------------------- # annotations # --------------------------------------------------------------------- - class TestUnit(Unit): def __init__(self, id, start, end): Unit.__init__(self, id, Span(start, end), '', {}) + class TestRelation(Relation): def __init__(self, id, start, end): Relation.__init__(self, id, RelSpan(start, end), '', {}) + class TestSchema(Schema): def __init__(self, id, units, relations, schemas): - Schema.__init__(self, id, frozenset(units), frozenset(relations), frozenset(schemas), '', {}) + Schema.__init__(self, id, frozenset(units), frozenset(relations), + frozenset(schemas), '', {}) + class TestDocument(Document): def __init__(self, units, rels, schemas, txt): Document.__init__(self, units, rels, schemas, txt) + def test_members(): - u1 = TestUnit('u1', 2, 4) - u2 = TestUnit('u2', 3, 9) - u3 = TestUnit('distractor', 1,10) - u4 = TestUnit('u4', 12,13) - u5 = TestUnit('u5', 4,12) - u6 = TestUnit('u6', 7,14) - s1 = TestSchema('s1', ['u4','u5','u6'], [], []) - r1 = TestRelation('r1', 's1','u2') - - doc = TestDocument([u1,u2,u3,u4,u5,u6],[r1],[s1], "why hello there!") + u1 = TestUnit('u1', 2, 4) + u2 = TestUnit('u2', 3, 9) + u3 = TestUnit('distractor', 1, 10) + u4 = TestUnit('u4', 12, 13) + u5 = TestUnit('u5', 4, 12) + u6 = TestUnit('u6', 7, 14) + s1 = TestSchema('s1', ['u4', 'u5', 'u6'], [], []) + r1 = TestRelation('r1', 's1', 'u2') + + doc = TestDocument([u1, u2, u3, u4, u5, u6], [r1], [s1], + "why hello there!") assert u1._members() is None - assert sorted(s1._members()) == sorted([u4,u5,u6]) - assert sorted(r1._members()) == sorted([u2,s1]) + assert sorted(s1._members()) == sorted([u4, u5, u6]) + assert sorted(r1._members()) == sorted([u2, s1]) assert u1._terminals() == [u1] - assert sorted(s1._terminals()) == sorted([u4,u5,u6]) - assert sorted(r1._terminals()) == sorted([u2,u4,u5,u6]) + assert sorted(s1._terminals()) == sorted([u4, u5, u6]) + assert sorted(r1._terminals()) == sorted([u2, u4, u5, u6]) doc_sp = doc.text_span() for x in doc.annotations(): sp = 
x.text_span() assert sp.char_start <= sp.char_end assert sp.char_start >= doc_sp.char_start - assert sp.char_end <= doc_sp.char_end + assert sp.char_end <= doc_sp.char_end + # --------------------------------------------------------------------- # graph # --------------------------------------------------------------------- - class FakeGraph(educe.Graph): """ Stand-in for educe.graph.Graph """ def __init__(self): educe.Graph.__init__(self) - self.corpus = {} + self.corpus = {} self.doc_key = None - self.doc = None + self.doc = None def _add_fake_node(self, anno_id, type): - attrs = { 'type' : type - } + attrs = { + 'type': type + } self.add_node(anno_id) for x in attrs.items(): self.add_node_attribute(anno_id, x) def _add_fake_edge(self, anno_id, type, members): - attrs = { 'type' : type - , 'mirror' : anno_id - } + attrs = { + 'type': type, + 'mirror': anno_id + } self.add_node(anno_id) self.add_edge(anno_id) for x in attrs.items(): self.add_edge_attribute(anno_id, x) self.add_node_attribute(anno_id, x) - for l in members: self.link(l, anno_id) + for l in members: + self.link(l, anno_id) def add_edus(self, *anno_ids): - for anno_id in anno_ids: self.add_edu(str(anno_id)) + for anno_id in anno_ids: + self.add_edu(str(anno_id)) def add_edu(self, anno_id): self._add_fake_node(anno_id, 'EDU') @@ -259,18 +268,19 @@ def add_rel(self, anno_id, node1, node2): self._add_fake_edge(anno_id, 'rel', [str(node1), str(node2)]) def add_cdu(self, anno_id, members): - self._add_fake_edge(anno_id, 'CDU', list(map(str,members))) + self._add_fake_edge(anno_id, 'CDU', list(map(str, members))) + class BasicGraphTest(unittest.TestCase): def test_cdu_members_trivial(self): "trivial CDU membership" gr = FakeGraph() - gr.add_edus(1,2,3) - gr.add_rel('a',1,2) - gr.add_cdu('X',[1,2]) + gr.add_edus(1, 2, 3) + gr.add_rel('a', 1, 2) + gr.add_cdu('X', [1, 2]) - members = gr.cdu_members('X') - expected = frozenset(['1','2']) + members = gr.cdu_members('X') + expected = frozenset(['1', '2']) self.assertEqual(expected, members) # this is probably not a desirable property, but is a consequence @@ -282,20 +292,20 @@ def test_cdu_neighbors(self): "does belong in the same CDU make you a neighbour?" gr = FakeGraph() - gr.add_edus('a1','a2','b') - gr.add_cdu('A',['a1','a2']) + gr.add_edus('a1', 'a2', 'b') + gr.add_cdu('A', ['a1', 'a2']) - ns1 = frozenset(gr.neighbors('a1')) + ns1 = frozenset(gr.neighbors('a1')) expected1 = frozenset(['a2']) self.assertEqual(expected1, ns1) - ns2 = frozenset(gr.neighbors('a2')) + ns2 = frozenset(gr.neighbors('a2')) expected2 = frozenset(['a1']) - self.assertEqual(expected2,ns2) + self.assertEqual(expected2, ns2) - ns3 = frozenset(gr.neighbors('b')) + ns3 = frozenset(gr.neighbors('b')) expected3 = frozenset([]) - self.assertEqual(expected3,ns3) + self.assertEqual(expected3, ns3) def test_copy(self): """ @@ -305,33 +315,33 @@ def test_copy(self): """ gr = FakeGraph() - gr.add_edus(*range(1,4)) - gr.add_edus(*range(10,14)) + gr.add_edus(*range(1, 4)) + gr.add_edus(*range(10, 14)) gr.add_rel('1.2', 1, 2) gr.add_rel('1.3', 1, 3) - gr.add_rel('2.11', 2, 11) # bridge! + gr.add_rel('2.11', 2, 11) # bridge! 
gr.add_rel('11.12', 11, 12) gr.add_rel('12.13', 11, 12) - gr.add_cdu('X1', [2,3]) - gr.add_cdu('X2', ['X1',1]) # should be copied - gr.add_cdu('Y1', [12,13]) - gr.add_cdu('XY', [1,13]) # should not be copied - - xset2 = set(map(str,[1,2,3])) - gr2 = gr.copy(nodeset=xset2) - self.assertEqual(xset2, gr2.edus()) - self.assertEqual(set(['1.2','1.3']), gr2.relations()) - self.assertEqual(set(['X1', 'X2']), gr2.cdus()) - self.assertEqual(gr.links('X2'), gr2.links('X2')) + gr.add_cdu('X1', [2, 3]) + gr.add_cdu('X2', ['X1', 1]) # should be copied + gr.add_cdu('Y1', [12, 13]) + gr.add_cdu('XY', [1, 13]) # should not be copied + + xset2 = set(map(str, [1, 2, 3])) + gr2 = gr.copy(nodeset=xset2) + self.assertEqual(xset2, gr2.edus()) + self.assertEqual(set(['1.2', '1.3']), gr2.relations()) + self.assertEqual(set(['X1', 'X2']), gr2.cdus()) + self.assertEqual(gr.links('X2'), gr2.links('X2')) # some nonsense copies xset3 = xset2 | set(['X1']) - gr3 = gr.copy(nodeset=xset3) - self.assertEqual(xset2, gr3.edus()) #not xset3 + gr3 = gr.copy(nodeset=xset3) + self.assertEqual(xset2, gr3.edus()) # not xset3 # including CDU should also result in members being included xset4 = set(['X2']) - gr4 = gr.copy(nodeset=xset4) + gr4 = gr.copy(nodeset=xset4) self.assertEqual(xset2, gr4.edus()) self.assertEqual(set(['X1', 'X2']), gr4.cdus()) From 3736c3f75c89d1c77fd62bedb149fbe704422397 Mon Sep 17 00:00:00 2001 From: moreymat Date: Sat, 28 Jan 2017 09:56:42 +0100 Subject: [PATCH 12/44] DOC+MAINT docstring, style --- educe/rst_dt/codra.py | 64 -------------------- educe/rst_dt/dep2con.py | 2 +- educe/rst_dt/deptree.py | 10 +++- educe/rst_dt/parse.py | 4 +- educe/rst_dt/sdrt.py | 21 +++---- educe/rst_dt/tests.py | 103 ++++++++++++++++---------------- educe/stac/context.py | 14 +---- educe/stac/fusion.py | 126 +++++++++++++++++++++++++++------------- educe/stac/postag.py | 15 +++++ educe/stac/rfc.py | 7 ++- educe/stac/tests.py | 8 +-- 11 files changed, 189 insertions(+), 185 deletions(-) delete mode 100644 educe/rst_dt/codra.py diff --git a/educe/rst_dt/codra.py b/educe/rst_dt/codra.py deleted file mode 100644 index 491f2b3..0000000 --- a/educe/rst_dt/codra.py +++ /dev/null @@ -1,64 +0,0 @@ -"""This module provides support for the CODRA discourse parser. -""" - -import codecs -import glob -import os - -from .parse import parse_rst_dt_tree - - -def load_codra_output_files(container_path, level='doc'): - """Load ctrees output by CODRA on the TEST section of RST-WSJ. - - Parameters - ---------- - container_path: string - Path to the main folder containing CODRA's output - - level: {'doc', 'sent'}, optional (default='doc') - Level of decoding: document-level or sentence-level - - Returns - ------- - data: dict - Dictionary that should be akin to a sklearn Bunch, with - interesting keys 'filenames', 'doc_names' and 'rst_ctrees'. - - Notes - ----- - To ensure compatibility with the rest of the code base, doc_names - are automatically added the ".out" extension. This would not work - for fileX documents, but they are absent from the TEST section of - the RST-WSJ treebank. 
- """ - if level == 'doc': - file_ext = '.doc_dis' - elif level == 'sent': - file_ext = '.sen_dis' - else: - raise ValueError("level {} not in ['doc', 'sent']".format(level)) - - # find all files with the right extension - pathname = os.path.join(container_path, '*{}'.format(file_ext)) - # filenames are sorted by name to avoid having to realign data - # loaded with different functions - filenames = sorted(glob.glob(pathname)) # glob.glob() returns a list - - # find corresponding doc names - doc_names = [os.path.splitext(os.path.basename(filename))[0] + '.out' - for filename in filenames] - - # load the RST trees - rst_ctrees = [] - for filename in filenames: - with codecs.open(filename, 'r', 'utf-8') as f: - # TODO (?) add support for and use RSTContext - rst_ctree = parse_rst_dt_tree(f.read(), None) - rst_ctrees.append(rst_ctree) - - data = dict(filenames=filenames, - doc_names=doc_names, - rst_ctrees=rst_ctrees) - - return data diff --git a/educe/rst_dt/dep2con.py b/educe/rst_dt/dep2con.py index 493919e..4ca2373 100644 --- a/educe/rst_dt/dep2con.py +++ b/educe/rst_dt/dep2con.py @@ -21,7 +21,7 @@ class DummyNuclearityClassifier(object): Parameters ---------- - strategy: str + strategy : str Strategy to use to generate predictions. * "unamb_else_most_frequent": predicts multinuclear when the diff --git a/educe/rst_dt/deptree.py b/educe/rst_dt/deptree.py index a7a3bb7..4d589c4 100644 --- a/educe/rst_dt/deptree.py +++ b/educe/rst_dt/deptree.py @@ -42,7 +42,15 @@ def __init__(self, msg): class RstDepTree(object): - """RST dependency tree""" + """RST dependency tree + + Attributes + ---------- + edus : list of EDU + List of the EDUs of this document. + origin : Document?, optional + TODO + """ def __init__(self, edus=[], origin=None): # FIXME find a clean way to avoid generating a new left padding EDU diff --git a/educe/rst_dt/parse.py b/educe/rst_dt/parse.py index 5698626..a3428aa 100644 --- a/educe/rst_dt/parse.py +++ b/educe/rst_dt/parse.py @@ -260,8 +260,8 @@ def walk(subtree, posinfo=PosInfo(text=0, edu=0)): match = _lw_type_re.match(treenode(subtree)) if not match: - raise RSTTreeException("Missing nuclearity annotation in ", - subtree) + raise RSTTreeException( + "Missing nuclearity annotation in " + str(subtree)) nuclearity = _lw_nuc_map[match.group("nuc")] rel = match.group("rel") or "leaf" edu_span = (start.edu, posinfo.edu - 1) diff --git a/educe/rst_dt/sdrt.py b/educe/rst_dt/sdrt.py index 78a2e79..2d91b42 100644 --- a/educe/rst_dt/sdrt.py +++ b/educe/rst_dt/sdrt.py @@ -20,14 +20,15 @@ class CDU: - """A CDU contains one or more discourse units, and tracks relation - instances between its members. - Both CDU and EDU are discourse units. + """Complex Discourse Unit. + + A CDU contains one or more discourse units, and tracks relation + instances between its members. Both CDU and EDU are discourse units. Attributes ---------- members : list of Unit or Scheme - Immediate members of this CDU. + Immediate member units (EDUs and CDUs) of this CDU. rel_insts : list of Relation Relation instances between immediate members of this CDU. 
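Since a CDU's members can themselves be CDUs, client code usually needs a small recursive walk to reach the underlying EDUs. A minimal sketch (the helper name and the hasattr test are illustrative, not part of this module):

def iter_edus(du):
    # walk a discourse unit (EDU or CDU) and yield its EDUs, assuming
    # only CDUs carry a `members` attribute as documented above
    if hasattr(du, 'members'):
        for member in du.members:
            for edu in iter_edus(member):
                yield edu
    else:
        yield du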
@@ -75,7 +76,7 @@ def debug_du_to_tree(m): rtype_str = list(rtypes)[0] if len(rtypes) == 1 else str(rtypes) return Tree(rtype_str, [debug_du_to_tree(x) for x in m.members]) else: - raise Exception("Don't know how to deal with non CDU/EDU") + raise ValueError("Don't know how to deal with non CDU/EDU") def rst_to_glozz_sdrt(rst_tree, annotator='ldc'): @@ -187,20 +188,20 @@ def rst_to_sdrt(tree): if len(tree) == 1: # pre-terminal edu = tree[0] if not isinstance(edu, rst.EDU): - raise Exception("Pre-terminal with non-EDU leaf: %s" % edu) + raise ValueError("Pre-terminal with non-EDU leaf: %s" % edu) return edu else: nuclei = [x for x in tree if x.label().is_nucleus()] satellites = [x for x in tree if x.label().is_satellite()] if len(nuclei) + len(satellites) != len(tree): - raise Exception("Nodes that are neither Nuclei nor " - "Satellites\n%s" % tree) + raise ValueError( + "Nodes that are neither Nuclei nor Satellites\n%s" % tree) if len(nuclei) == 0: - raise Exception("No nucleus:\n%s" % tree) + raise ValueError("No nucleus:\n%s" % tree) elif len(nuclei) > 1: # multi-nuclear chain if satellites: - raise Exception("Multinuclear with satellites:\n%s" % tree) + raise ValueError("Multinuclear with satellites:\n%s" % tree) c_nucs = [rst_to_sdrt(x) for x in nuclei] rtype = nuclei[0].label().rel rel_insts = set(RelInst(n1, n2, rtype) diff --git a/educe/rst_dt/tests.py b/educe/rst_dt/tests.py index 0a2933e..c4ed414 100644 --- a/educe/rst_dt/tests.py +++ b/educe/rst_dt/tests.py @@ -58,20 +58,20 @@ ) """ -TEXT1 = " ".join( - [" ORGANIZING YOUR MATERIALS ", - (" Once you've decided on the kind of paneling you want to install " - "--- and the pattern ---"), - - "some preliminary steps remain", - "before you climb into your working clothes. ", - " You'll need to measure the wall or room to be paneled,", - "estimate the amount of paneling you'll need,", - "buy the paneling,", - ("gather the necessary tools and equipment (see illustration " - "on page 87),"), - - "and even condition certain types of paneling before installation. " +TEXT1 = " ".join([ + " ORGANIZING YOUR MATERIALS ", + (" Once you've decided on the kind of paneling you want to install " + "--- and the pattern ---"), + + "some preliminary steps remain", + "before you climb into your working clothes. ", + " You'll need to measure the wall or room to be paneled,", + "estimate the amount of paneling you'll need,", + "buy the paneling,", + ("gather the necessary tools and equipment (see illustration " + "on page 87),"), + + "and even condition certain types of paneling before installation. 
" ]) @@ -117,20 +117,21 @@ def test_binarize(self): self.assertTrue(annotation.is_binary(bin_tree)) def test_rst_to_dt(self): - lw_trees = ["(R:rel (S x) (N y))", - - """ - (R:rel - (S x) - (N:rel (N h) (S t))) - """, - - """ - (R:r - (S x) - (N:r (N:r (S t1) (N h)) - (S t2))) - """ + lw_trees = [ + "(R:rel (S x) (N y))", + + """ + (R:rel + (S x) + (N:rel (N h) (S t))) + """, + + """ + (R:r + (S x) + (N:r (N:r (S t1) (N h)) + (S t2))) + """ ] for lstr in lw_trees: @@ -151,27 +152,28 @@ def test_rst_to_dt(self): "edu span equality on " + name) def test_dt_to_rst_order(self): - lw_trees = ["(R:r (N:r (N h) (S r1)) (S r2))", - "(R:r (S:r (S l2) (N l1)) (N h))", - "(R:r (N:r (S l1) (N h)) (S r1))", - """ - (R:r - (N:r - (N:r (S l2) - (N:r (S l1) - (N h))) - (S r1)) - (S r2)) - """, # ((l2 <- l1 <- h) -> r1 -> r2) - """ - (R:r - (N:r - (S l2) - (N:r (N:r (S l1) - (N h)) - (S r1))) - (S r2)) - """, # (l2 <- ((l1 <- h) -> r1)) -> r2 + lw_trees = [ + "(R:r (N:r (N h) (S r1)) (S r2))", + "(R:r (S:r (S l2) (N l1)) (N h))", + "(R:r (N:r (S l1) (N h)) (S r1))", + """ + (R:r + (N:r + (N:r (S l2) + (N:r (S l1) + (N h))) + (S r1)) + (S r2)) + """, # ((l2 <- l1 <- h) -> r1 -> r2) + """ + (R:r + (N:r + (S l2) + (N:r (N:r (S l1) + (N h)) + (S r1))) + (S r2)) + """, # (l2 <- ((l1 <- h) -> r1)) -> r2 ] for lstr in lw_trees: @@ -185,10 +187,12 @@ def test_dt_to_rst_order(self): dep_b = copy.deepcopy(dep) dep_b.deps(0).reverse() rst2b = deptree_to_simple_rst_tree(dep_b) + # TODO assertion on rst2b? dep_c = copy.deepcopy(dep) random.shuffle(dep_c.deps(0)) rst2c = deptree_to_simple_rst_tree(dep_c) + # TODO assertion on rst2c? def test_rst_to_dt_nuclearity_loss(self): """ @@ -229,3 +233,4 @@ def test_rst_to_dt_nuclearity_loss(self): dep1 = RstDepTree.from_simple_rst_tree(rst1) rev1 = deptree_to_simple_rst_tree(dep1) # was:, ['r']) # self.assertEqual(rst0, rev1, "same structure " + tricky) + # TODO restore a meaningful assertion diff --git a/educe/stac/context.py b/educe/stac/context.py index 061c494..d2f3b90 100644 --- a/educe/stac/context.py +++ b/educe/stac/context.py @@ -133,14 +133,8 @@ class Context(object): (may not be present): tokens contained within this EDU """ # pylint: disable=too-many-arguments - def __init__(self, - turn, - tstar, - turn_edus, - dialogue, - dialogue_turns, - doc_turns, - tokens=None): + def __init__(self, turn, tstar, turn_edus, dialogue, dialogue_turns, + doc_turns, tokens=None): self.turn = turn self.tstar = tstar self.turn_edus = turn_edus @@ -221,13 +215,11 @@ def _for_edu(cls, enclosure, doc_turns, doc_tstars, edu): @classmethod def for_edus(cls, doc, postags=None): - """ - Return a dictionary of context objects for each EDU in the document + """Get a dictionary of context objects for each EDU in the doc. Returns ------- contexts: dict(educe.glozz.Unit, Context) - A dictionary with a context For each EDU in the document """ if postags: diff --git a/educe/stac/fusion.py b/educe/stac/fusion.py index c5222e2..55f54f2 100644 --- a/educe/stac/fusion.py +++ b/educe/stac/fusion.py @@ -1,9 +1,9 @@ -"""Somewhat higher level representation of STAC documents -than the usual Glozz layer. +"""Somewhat higher level representation of STAC documents than the usual +Glozz layer. Note that this is a relatively recent addition to Educe. 
-Up to the time of this writing (2015-03), we had two options -for dealing with STAC: +Up to the time of this writing (2015-03), we had two options for dealing +with STAC: * manually manipulating glozz objects via educe.annotation * dealing with some high-level but not particularly helpful @@ -19,18 +19,18 @@ This has always been a bit awkward when dealing with Glozz, because there are separate annotations in different Glozz documents, the dialogue acts in the 'units' stage; and the linked units in the discourse stage. -Combining these streams has always involved a certain amount of manual lookup, -which we hope to avoid with this fusion layer. +Combining these streams has always involved a certain amount of manual +lookup, which we hope to avoid with this fusion layer. -At the time of this writing, this will have a bit of emphasis on -feature-extraction +At the time of this writing, this will have a bit of emphasis on feature +extraction. """ # pylint: disable=too-few-public-methods from __future__ import print_function import copy -import itertools as itr +import itertools from educe.annotation import (Span, Unit) from educe.stac.annotation import (is_edu, speaker, turn_id, twin_from) @@ -43,26 +43,47 @@ class Dialogue(object): """STAC Dialogue - Note that input EDUs should be sorted by span + Note that input EDUs should be sorted by span. + + Parameters + ---------- + anno : educe.stac.annotation.Unit + Glozz annotation corresponding to the dialogue ; only its + identifier is stored, currently. + + edus : list(educe.stac.annotation.Unit) + List of EDU annotations, sorted by their span. + + relations : list(educe.stac.annotation.Relation + List of relations between EDUs from the dialogue. + """ def __init__(self, anno, edus, relations): - self.edus = [FakeRootEDU] + edus self.grouping = anno.identifier() + self.edus = [FakeRootEDU] + edus + self.relations = relations # we start from 1 because 0 is for the fake root self.edu2sent = {i: e.subgrouping() for i, e in enumerate(edus, start=1)} - self.relations = relations def edu_pairs(self): - """Return all EDU pairs within this dialogue. + """Generate all EDU pairs within this dialogue. + + This includes pairs whose source is the left padding (fake root) + EDU. - NB: this is a generator + Yields + ------ + (source, target) : tuple(educe.stac.annotation.Unit) + Next candidate edge, as a pair of EDUs (source, target). 
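A minimal sketch of how the candidate pairs enumerate, assuming `dlg` is a Dialogue built from a fused document; the counts follow from the generator above (one pair per real EDU for the fake root, plus every ordered pair of distinct real EDUs):

cand_pairs = list(dlg.edu_pairs())
n_real = len(dlg.edus) - 1          # dlg.edus[0] is the left padding (fake root) EDU
assert len(cand_pairs) == n_real + n_real * (n_real - 1)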
""" i_edus = list(enumerate(self.edus)) _, fakeroot = i_edus[0] i_edus = i_edus[1:] # drop left padding EDU for _, edu in i_edus: yield (fakeroot, edu) + # generate all pairs of (real) EDUs + # real_pairs = [] # DEBUG for num1, edu1 in i_edus: def is_before(numedu2): 'true if we have seen the EDU already' @@ -70,9 +91,19 @@ def is_before(numedu2): num2 = numedu2[0] return num2 <= num1 # pylint: enable=cell-var-from-loop - for _, edu2 in itr.dropwhile(is_before, i_edus): + for _, edu2 in itertools.dropwhile(is_before, i_edus): yield (edu1, edu2) yield (edu2, edu1) + # DEBUG + # real_pairs.append((edu1, edu2)) + # real_pairs.append((edu2, edu1)) + # end DEBUG + # DEBUG compare list of EDU pairs from the above loop with a + # one-liner + # real_pairs_itr = sorted(itertools.permutations(self.edus[1:])) + # assert real_pairs_itr != sorted(real_pairs) + # raise ValueError("woooop") + # end DEBUG # pylint: disable=too-many-instance-attributes @@ -87,15 +118,13 @@ class EDU(Unit): It also tries to be usable as a drop-in substitute for both annotations and contexts """ - def __init__(self, doc, - discourse_anno, - unit_anno): + def __init__(self, doc, discourse_anno, unit_anno): self._doc = doc self._anno = discourse_anno self._unit_anno = unit_anno unit_anno = unit_anno or discourse_anno - unit_type = unit_anno.type if is_edu(unit_anno)\ - else discourse_anno.type + unit_type = (unit_anno.type if is_edu(unit_anno) + else discourse_anno.type) super(EDU, self).__init__(discourse_anno.local_id(), discourse_anno.text_span(), unit_type, @@ -214,31 +243,50 @@ def speaker(self): def fuse_edus(discourse_doc, unit_doc, postags): - """Return a copy of the discourse level doc, merging info - from both the discourse and units stage. + """Return a copy of the discourse level doc, merging info from both + the discourse and units stage. All EDUs will be converted to higher level EDUs. Notes ----- - * The discourse stage is primary in that we work by going over what EDUs - we find in the discourse stage and trying to enhance them with - information we find on their units-level equivalents. Sometimes (rarely - but it happens) annotations can go out of synch. EDUs missing on the - units stage will be silently ignored (we try to make do without them). + * The discourse stage is primary in that we work by going over what + EDUs we find in the discourse stage and trying to enhance them + with information we find on their units-level equivalents. + Sometimes (rarely but it happens) annotations can go out of synch. + EDUs missing on the units stage will be silently ignored (we try + to make do without them). EDUs that were introduced on the units stage but not percolated to discourse will also be ignored. - * We rely on annotation ids to match EDUs from both stages; it's up to you - to ensure that the annotations are really in synch. - - * This does not constitute a full merge of the documents. For a full merge, - you would have to bring over other annotations such as Resources, - `Preference`, `Anaphor`, `Several_resources`, taking care all the while - to ensure there are no timestamp clashes with pre-existing annotations - (it's unlikely but best be on the safe side if you ever find yourself - with automatically generated annotations, where all bets are off - time-stamp wise). + * We rely on annotation ids to match EDUs from both stages; it's up + to you to ensure that the annotations are really in synch. + + * This does not constitute a full merge of the documents. 
For a full + merge, you would have to bring over other annotations such as + Resources, `Preference`, `Anaphor`, `Several_resources`, taking + care all the while to ensure there are no timestamp clashes with + pre-existing annotations (it's unlikely but best be on the safe + side if you ever find yourself with automatically generated + annotations, where all bets are off time-stamp wise). + + Parameters + ---------- + discourse_doc : GlozzDocument + Document from the "discourse" stage. + + unit_doc : GlozzDocument + Document from the "units" stage. + + postags : list of Token + Sequence of educe tokens predicted by the POS tagger for this + document. + + Returns + ------- + doc : GlozzDocument + Deep copy of the discourse_doc with info from the units stage + merged in. """ doc = copy.deepcopy(discourse_doc) @@ -251,7 +299,7 @@ def fuse_edus(discourse_doc, unit_doc, postags): edu = EDU(doc, anno, unit_anno) replacements[anno] = edu - # second pass: rewrite doc so that annotations that corresponds + # second pass: rewrite doc so that annotations that correspond # to EDUs are replacement by their higher-level equivalents edus = [] for anno in annos: @@ -270,7 +318,7 @@ def fuse_edus(discourse_doc, unit_doc, postags): schema.units.append(edu) # fourth pass: flesh out the EDUs with contextual info - # now the EDUs should be work as contexts too + # now the EDUs should work as contexts too contexts = Context.for_edus(doc, postags=postags) for edu in edus: edu.fleshout(contexts[edu]) diff --git a/educe/stac/postag.py b/educe/stac/postag.py index ba89d68..33155ee 100644 --- a/educe/stac/postag.py +++ b/educe/stac/postag.py @@ -111,6 +111,21 @@ def read_tags(corpus, root_dir): educe.annotation.Standoff objects. Return a dictionary mapping 'FileId's to sets of tokens. + + Parameters + ---------- + corpus : dict(FileId, GlozzDocument) + Dictionary of documents keyed by their FileId. + + root_dir : string + Path to the directory containing the output of the POS tagger, + one file per document. + + Returns + ------- + pos_tags : dict(FileId, list(Token)) + Map from each document id to the list of tokens predicted by a + POS tagger. 
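A minimal sketch tying the two functions documented above together; `corpus`, `key`, `tagger_dir` and the two stage documents are assumed to come from the usual corpus reading code:

from educe.stac.fusion import fuse_edus
from educe.stac.postag import read_tags

postags = read_tags(corpus, tagger_dir)[key]         # key: FileId of this document
fused = fuse_edus(discourse_doc, unit_doc, postags)
# `fused` is a deep copy of discourse_doc whose EDUs also work as contexts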
""" pos_tags = {} for k in corpus: diff --git a/educe/stac/rfc.py b/educe/stac/rfc.py index 2f3a746..4c827e1 100644 --- a/educe/stac/rfc.py +++ b/educe/stac/rfc.py @@ -7,7 +7,7 @@ from educe import stac from educe.stac.context import Context -from .annotation import (is_subordinating) +from .annotation import is_subordinating # pylint: disable=too-few-public-methods, no-self-use @@ -154,8 +154,9 @@ def violations(self): if not self._is_incoming_to(new_node, lnk): continue src_node, _ = graph.rel_links(lnk) - if ((last_node is None - or not self._is_on_frontier(last_node, src_node))): + if (last_node is None + or not self._is_on_frontier(last_node, src_node)): + # add link to set of violations violations.append(lnk) return violations diff --git a/educe/stac/tests.py b/educe/stac/tests.py index 99cb2f5..8f8a008 100644 --- a/educe/stac/tests.py +++ b/educe/stac/tests.py @@ -16,9 +16,9 @@ import educe.stac.graph as stac_gr from educe import annotation, corpus, stac +from educe.corpus import FileId from educe.stac import fake_graph from educe.stac.rfc import BasicRfc, ThreadedRfc -from educe.corpus import FileId from educe.stac.util.output import mk_parent_dirs @@ -39,10 +39,8 @@ def __init__(self, schema_id, members): edus = set(x.local_id() for x in members if isinstance(x, FakeEDU)) cdus = set(x.local_id() for x in members if isinstance(x, FakeCDU)) rels = set() - annotation.Schema.__init__(self, schema_id, - edus, rels, cdus, - 'Complex_discourse_unit', - {}, {}) + annotation.Schema.__init__(self, schema_id, edus, rels, cdus, + 'Complex_discourse_unit', {}, {}) class FakeDocument(annotation.Document): From 8ef1740c00ff56f4d2815b43e3a7d8c471039179 Mon Sep 17 00:00:00 2001 From: moreymat Date: Mon, 30 Jan 2017 11:33:27 +0100 Subject: [PATCH 13/44] MAINT educe.rst_dt refactoring, same_unit ; stac inquirer --- educe/rst_dt/annotation.py | 171 ++++- educe/rst_dt/corenlp.py | 8 +- educe/rst_dt/corpus.py | 45 +- educe/rst_dt/corpus_diagnostics.py | 33 +- educe/rst_dt/dep2con.py | 679 +++++++++++++------ educe/rst_dt/deptree.py | 115 +++- educe/rst_dt/document_plus.py | 96 ++- educe/rst_dt/rst_relations.py | 682 ++++++++++++++++++++ educe/rst_dt/util/cmd/check_tokenization.py | 121 ++++ educe/rst_dt/util/cmd/deptree.py | 16 +- educe/rst_dt/util/cmd/text.py | 6 +- educe/rst_dt/util/cmd/tmp.py | 5 +- educe/stac/lexicon/inquirer.py | 39 ++ 13 files changed, 1719 insertions(+), 297 deletions(-) create mode 100644 educe/rst_dt/rst_relations.py create mode 100644 educe/rst_dt/util/cmd/check_tokenization.py create mode 100644 educe/stac/lexicon/inquirer.py diff --git a/educe/rst_dt/annotation.py b/educe/rst_dt/annotation.py index bdb004e..9e4071b 100644 --- a/educe/rst_dt/annotation.py +++ b/educe/rst_dt/annotation.py @@ -18,6 +18,9 @@ import subprocess import tempfile +# nltk.draw for rendering in PS, PDF, PNG ; see RSTTree.to_ps() +from nltk.draw.tree import tree_to_treesegment +from nltk.draw.util import CanvasFrame from nltk.internals import find_binary from educe.annotation import Standoff, Span @@ -25,6 +28,11 @@ from ..internalutil import treenode +# nuclearities +NUC_N = "Nucleus" +NUC_S = "Satellite" +NUC_R = "Root" + # ghostscript parameters to generate images in different formats _GS_PARAMS = { 'png': '-sDEVICE=png16m -r90 -dTextAlphaBits=4 -dGraphicsAlphaBits=4', @@ -81,9 +89,7 @@ class EDU(Standoff): """ _SUMMARY_LEN = 20 - def __init__(self, num, span, text, - context=None, - origin=None): + def __init__(self, num, span, text, context=None, origin=None): super(EDU, self).__init__(origin) 
self.num = num @@ -191,8 +197,7 @@ class Node(object): A node in an `RSTTree` or `SimpleRSTTree`. """ - def __init__(self, nuclearity, edu_span, span, rel, - context=None): + def __init__(self, nuclearity, edu_span, span, rel, context=None): self.nuclearity = nuclearity "one of Nucleus, Satellite, Root" @@ -221,7 +226,7 @@ def __repr__(self): def __str__(self): return "%s %s %s" % ( "%s-%s" % self.edu_span, - self.nuclearity[0], + self.nuclearity, self.rel) def __eq__(self, other): @@ -241,13 +246,13 @@ def is_nucleus(self): can only either be nucleus/satellite or much more rarely, root. """ - return self.nuclearity == 'Nucleus' + return self.nuclearity == NUC_N def is_satellite(self): """ A node can either be a nucleus, a satellite, or a root node. """ - return self.nuclearity == 'Satellite' + return self.nuclearity == NUC_S # pylint: disable=R0904, E1103 @@ -257,13 +262,39 @@ class RSTTree(SearchableTree, Standoff): raw RST discourse treebank one. """ - def __init__(self, node, children, - origin=None): + def __init__(self, node, children, origin=None, verbose=False): """ See `educe.rst_dt.parse` to build trees from strings """ SearchableTree.__init__(self, node, children) Standoff.__init__(self, origin) + # WIP 2016-11-10 store num of head in node + if len(children) == 1 and isinstance(children[0], EDU): + # pre-terminal: head is num of terminal (EDU) + node.head = children[0].num + else: + # internal node + kids_nuclei = [i for i, kid in enumerate(children) + if kid.label().nuclearity == NUC_N] + if len(kids_nuclei) == 1: + # 1 nucleus, 1-n satellites: n mono-nuc relations + pass + elif len(kids_nuclei) == len(children): + # all children are nuclei: 1 multi-nuc relation + kid_rels = [kid.label().rel for kid in children] + if len(set(kid_rels)) > 1: + if verbose: + err_msg = ('W: More than one label in multi-nuclear' + ' relation {}'.format(children)) + print(err_msg) + else: + # corner case, should not happen + err_msg = 'E: Unknown pattern in children' + print(err_msg) + # its head is the head of its leftmost nucleus child + lnuc = children[kids_nuclei[0]] + node.head = lnuc.label().head + # end WIP head def set_origin(self, origin): """ @@ -317,8 +348,6 @@ def to_ps(self, filename): This function is used by `_repr_png_`. """ - from nltk.draw.tree import tree_to_treesegment - from nltk.draw.util import CanvasFrame _canvas_frame = CanvasFrame() # WIP customization of visual appearance # NB: conda-provided python and tk cannot access most fonts on the @@ -359,6 +388,38 @@ def edu_span(self): """ return treenode(self).edu_span + def get_spans(self, subtree_filter=None, exclude_root=False): + """Get the spans of a constituency tree. + + Each span is described by a triplet (edu_span, nuclearity, + relation). + + Parameters + ---------- + subtree_filter: function, defaults to None + Function to filter all local trees. + + exclude_root: boolean, defaults to False + If True, exclude the span of the root node. This cannot be + expressed with `subtree_filter` because the latter is limited + to properties local to each subtree in isolation. Or maybe I + just missed something. + + Returns + ------- + spans: list of tuple((int, int), str, str) + List of tuples, each describing a span with a tuple + ((edu_start, edu_end), nuclearity, relation). 
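Note that the list returned above actually holds 4-tuples: the head EDU number (added 2016-11-10 in the code) comes after the relation. A minimal usage sketch, assuming `ctree` is an RSTTree loaded from the corpus:

spans = ctree.get_spans(exclude_root=True)
for (edu_start, edu_end), nuc, rel, head in spans:
    # `head` is the EDU number propagated up from the leftmost nucleus
    print(edu_start, edu_end, nuc, rel, head)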
+ """ + tnodes = [x.label() for x in self.subtrees(filter=subtree_filter) + if isinstance(x, RSTTree)] + if exclude_root: + tnodes = tnodes[1:] + # 2016-11-10 add a 4th element: head + spans = [(tn.edu_span, tn.nuclearity, tn.rel, tn.head) + for tn in tnodes] + return spans + def text(self): """ Return the text corresponding to this RST subtree. @@ -394,6 +455,14 @@ def __init__(self, node, children, origin=None): """ SearchableTree.__init__(self, node, children) Standoff.__init__(self, origin) + # WIP 2016-11-10 store num of head in node + if len(children) == 1 and isinstance(children[0], EDU): + node.head = children[0].num + else: + # head is head of the leftmost nucleus child + lnuc_idx = node.nuclearity.index('N') + node.head = children[lnuc_idx].label().head + # end WIP head def set_origin(self, origin): """ @@ -410,6 +479,38 @@ def text_span(self): def _members(self): return list(self) # children + def get_spans(self, subtree_filter=None, exclude_root=False): + """Get the spans of a constituency tree. + + Each span is described by a triplet (edu_span, nuclearity, + relation). + + Parameters + ---------- + subtree_filter: function, defaults to None + Function to filter all local trees. + + exclude_root: boolean, defaults to False + If True, exclude the span of the root node. This cannot be + expressed with `subtree_filter` because the latter is limited + to properties local to each subtree in isolation. Or maybe I + just missed something. + + Returns + ------- + spans: list of tuple((int, int), str, str) + List of tuples, each describing a span with a tuple + ((edu_start, edu_end), nuclearity, relation). + """ + tnodes = [x.label() for x in self.subtrees(filter=subtree_filter) + if isinstance(x, SimpleRSTTree)] + if exclude_root: + tnodes = tnodes[1:] + # 2016-11-10 add a 4th element: head + spans = [(tn.edu_span, tn.nuclearity, tn.rel, tn.head) + for tn in tnodes] + return spans + @classmethod def from_rst_tree(cls, tree): """ @@ -428,6 +529,7 @@ def _from_binary_rst_tree(cls, tree): if len(tree) == 1: node = copy.copy(treenode(tree)) node.rel = "leaf" + node.nuclearity = "leaf" # WIP return SimpleRSTTree(node, tree, tree.origin) else: left = tree[0] @@ -436,6 +538,9 @@ def _from_binary_rst_tree(cls, tree): lnode = treenode(left) rnode = treenode(right) node.rel = rnode.rel if rnode.is_satellite() else lnode.rel + # WIP move nuclearity up too + node.nuclearity = ''.join(x.label().nuclearity[0] for x in tree) + # end WIP kids = [cls._from_binary_rst_tree(kid) for kid in tree] return SimpleRSTTree(node, kids, tree.origin) @@ -479,7 +584,7 @@ def incorporate_nuclearity_into_label(cls, tree): return SimpleRSTTree(node, kids, tree.origin) @classmethod - def to_binary_rst_tree(cls, tree, rel='ROOT'): + def to_binary_rst_tree(cls, tree, rel='---', nuc=NUC_R): """ Build and return a binary `RSTTree` from a `SimpleRSTTree`. 
@@ -490,45 +595,51 @@ def to_binary_rst_tree(cls, tree, rel='ROOT'): Parameters ---------- - tree: SimpleRSTTree + tree : SimpleRSTTree SimpleRSTTree to convert - rel: string, optional - Relation that must decorate the root node of the output + rel : string, optional + Relation for the root node of the output + + nuc : string, optional + Nuclearity for the root node of the output Returns ------- - rtree: RSTTree + rtree : RSTTree The (binary) RSTTree that corresponds to the given SimpleRSTTree """ if len(tree) == 1: node = copy.copy(treenode(tree)) node.rel = rel + node.nuclearity = nuc return RSTTree(node, tree, tree.origin) else: - # left = tree[0] - # right = tree[1] node = copy.copy(treenode(tree)) - # lnode = treenode(left) - # rnode = treenode(right) # standard RST trees mark relations on the satellite # child (mononuclear relations) or on each nucleus # child (multinuclear relations) - sat_idx = [i for i, kid in enumerate(tree) - if treenode(kid).is_satellite()] + sat_idx = [i for i, nuc0 in enumerate(node.nuclearity) + if nuc0 == NUC_S[0]] if sat_idx: # mononuclear - kids = [(cls.to_binary_rst_tree(kid, rel=node.rel) - if treenode(kid).is_satellite() else - cls.to_binary_rst_tree(kid, rel='span')) - for kid in tree] + kids = [ + cls.to_binary_rst_tree( + kid, + rel=(node.rel if node.nuclearity[i] == NUC_S[0] + else 'span'), + nuc=(NUC_S if node.nuclearity[i] == NUC_S[0] + else NUC_N)) + for i, kid in enumerate(tree) + ] else: # multinuclear - kids = [cls.to_binary_rst_tree(kid, rel=node.rel) + kids = [cls.to_binary_rst_tree(kid, rel=node.rel, nuc=NUC_N) for kid in tree] - # update the rel in the current node + # update the rel and nuc in the current node node.rel = rel + node.nuclearity = nuc return RSTTree(node, kids, tree.origin) @@ -546,7 +657,7 @@ def builder(right, left): rnode = treenode(right) edu_span = (lnode.edu_span[0], rnode.edu_span[1]) span = lnode.span.merge(rnode.span) - newnode = Node('Nucleus', edu_span, span, rel) + newnode = Node(NUC_N, edu_span, span, rel) return RSTTree(newnode, [left, right], origin=left.origin) return functools.reduce(builder, kids[::-1]) diff --git a/educe/rst_dt/corenlp.py b/educe/rst_dt/corenlp.py index be321bf..bfcd78b 100644 --- a/educe/rst_dt/corenlp.py +++ b/educe/rst_dt/corenlp.py @@ -14,11 +14,11 @@ import nltk.tree -from educe.external.corenlp import (CoreNlpToken, CoreNlpDocument) -from educe.external.coref import (Chain, Mention) -from educe.external.parser import (ConstituencyTree, DependencyTree) +from educe.external.corenlp import CoreNlpToken, CoreNlpDocument +from educe.external.coref import Chain, Mention +from educe.external.parser import ConstituencyTree, DependencyTree from educe.external.stanford_xml_reader import PreprocessingSource -from educe.ptb.annotation import (transform_tree, strip_subcategory) +from educe.ptb.annotation import transform_tree, strip_subcategory from educe.ptb.head_finder import find_lexical_heads diff --git a/educe/rst_dt/corpus.py b/educe/rst_dt/corpus.py index 43f6971..faf0cfd 100644 --- a/educe/rst_dt/corpus.py +++ b/educe/rst_dt/corpus.py @@ -5,9 +5,11 @@ Corpus management (re-exported by educe.rst_dt) """ -from glob import glob import os import sys +from glob import glob +from os.path import dirname +from os.path import join from nltk import Tree @@ -17,7 +19,7 @@ import educe.util import educe.corpus from .document_plus import DocumentPlus -from .annotation import SimpleRSTTree, _binarize +from .annotation import SimpleRSTTree from .deptree import RstDepTree from .pseudo_relations import 
rewrite_pseudo_rels @@ -116,7 +118,7 @@ class RstDtParser(object): If True, relation labels are converted to their coarse-grained equivalent. - nary_conv : string, optional + nary_enc : string, optional Conversion method from constituency to dependency tree, for n-ary spans, n > 2, whose kids are all nuclei: 'tree' picks the leftmost nucleus as the head of all the others @@ -141,7 +143,7 @@ class RstDtParser(object): def __init__(self, corpus_dir, args, coarse_rels=False, fix_pseudo_rels=False, - nary_conv='chain', + nary_enc='chain', nuc_in_label=False, exclude_file_docs=False): self.reader = Reader(corpus_dir) @@ -160,10 +162,10 @@ def __init__(self, corpus_dir, args, coarse_rels=False, else: self.rel_conv = None # how to convert n-ary spans - self.nary_conv = nary_conv - if nary_conv not in ['chain', 'tree']: + self.nary_enc = nary_enc + if nary_enc not in ['chain', 'tree']: err_msg = 'Unknown conversion for n-ary spans: {}' - raise ValueError(err_msg.format(nary_conv)) + raise ValueError(err_msg.format(nary_enc)) # whether nuclearity should be part of the label self.nuc_in_label = nuc_in_label @@ -216,14 +218,8 @@ def decode(self, doc_key): # end TO DEPRECATE # convert to dep tree - # WIP - if self.nary_conv == 'chain': - # legacy mode, through SimpleRSTTree - # deptree = RstDepTree.from_simple_rst_tree(rsttree) - # modern mode, directly from a binarized RSTTree - deptree = RstDepTree.from_rst_tree(_binarize(orig_rsttree)) - else: # tree conversion - deptree = RstDepTree.from_rst_tree(orig_rsttree) + deptree = RstDepTree.from_rst_tree(orig_rsttree, + nary_enc=self.nary_enc) # end WIP doc.deptree = deptree @@ -278,3 +274,22 @@ def convert_tree(self, rst_tree): # replace old rel with new rel node.rel = conv_lbl(node.rel) return rst_tree + + def convert_dtree(self, dtree): + """Change relation labels in an RstDepTree using the label mapping. + + See attribute `self.convert_label`. + + Parameters + ---------- + dtree : RstDepTree + RST dtree + + Returns + ------- + dtree : RstDepTree + RST dtree with mapped labels. 
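A minimal sketch of the reading pipeline around these changes; `corpus_dir`, `args` (the usual educe corpus-filtering options) and `doc_key` are assumed to come from the caller:

from educe.rst_dt.corpus import RstDtParser

parser = RstDtParser(corpus_dir, args, coarse_rels=True, nary_enc='chain')
doc = parser.decode(doc_key)   # document object with doc.deptree built using nary_enc='chain'
# with coarse_rels=True a relation converter is attached (rel_conv); the new
# convert_dtree(dtree) defined alongside convert_tree relabels an RstDepTree in place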
+ """ + conv_lbl = self.convert_label + dtree.labels[1:] = [conv_lbl(x) for x in dtree.labels[1:]] + return dtree diff --git a/educe/rst_dt/corpus_diagnostics.py b/educe/rst_dt/corpus_diagnostics.py index 0237a66..bc04a99 100644 --- a/educe/rst_dt/corpus_diagnostics.py +++ b/educe/rst_dt/corpus_diagnostics.py @@ -359,6 +359,11 @@ def load_spans(coarse_rtree_ref): kid_nucs = ''.join(('N' if kid.label().nuclearity == 'Nucleus' else 'S') for kid in node) + # 2016-09-15 quick n dirty check for n-ary nodes with >1 S + cnt_nucs = Counter(kid_nucs) + if cnt_nucs['S'] > 1: + print(node.origin.doc, node_label.edu_span, kid_nucs) + # end n-ary nodes with >1 S else: # missing values, for pre-terminals (pre-EDUs) kid_rels = None kid_nucs = None @@ -811,6 +816,12 @@ def gather_leaky_stats(): """ nodes_train, rels_train, edus_train, sents_train, paras_train = load_corpus_as_dataframe_new(selection='train', binarize=False) + + # 2016-09-14 + print('test') + nodes_test, rels_test, edus_test, sents_test, paras_test = load_corpus_as_dataframe_new(selection='test', binarize=False) + # end 2016-09-14 + # print(rels_train) # as of version 0.17, pandas handles missing boolean values by degrading # column type to object, which makes boolean selection return true for @@ -862,8 +873,8 @@ def gather_leaky_stats(): # new hypothesis: 75% of leaky sentences can be split so that their EDUs # + the neighboring sentences form complete spans if False: - print(leaky_sents[['parent_sent_len', 'parent_sent_dist']].describe( - percentiles=[.1, .2, .3, .4, .5, .6, .7, .8, .9])) + print(leaky_sents[['parent_sent_len', 'parent_sent_dist']] + .describe(percentiles=[.1, .2, .3, .4, .5, .6, .7, .8, .9])) print(leaky_sents[(leaky_sents['parent_sent_dist'] == 1)].describe()) print(leaky_sents[(leaky_sents['parent_sent_dist'] == 1) & (leaky_sents['parent_sent_len'] > 2)]) @@ -872,17 +883,20 @@ def gather_leaky_stats(): # taxonomy of leaky sentences print(complex_sents.groupby('leaky_type')['edu_len'] .describe().unstack()) + # absolute value counts print(complex_sents.groupby('edu_len')['leaky_type'] - .value_counts(normalize=False).unstack()) # absolute value counts + .value_counts(normalize=False).unstack()) + # normalized value counts print(complex_sents.groupby('edu_len')['leaky_type'] - .value_counts(normalize=True).unstack()) # normalized value counts + .value_counts(normalize=True).unstack()) # WIP straddling relations strad_rels_df = pd.DataFrame(strad_rels_rows) print() print(strad_rels_df['sent_id'].describe()['count']) - print(strad_rels_df.groupby(['kid_rels']).describe()['sent_id'] - .unstack().sort_values('count', ascending=False)) + print(strad_rels_df.groupby(['kid_rels']) + .describe()['sent_id'].unstack() + .sort_values('count', ascending=False)) # compare to distribution of intra/inter relations print() print(rels_train[rels_train['edu_len'] > 1] @@ -928,16 +942,15 @@ def gather_leaky_stats(): # compare leaky with non-leaky complex paragraphss: EDU length print('EDU span length of leaky vs non-leaky complex paragraphs') - print(complex_paras.groupby('leaky')['edu_span_len'].describe() - .unstack()) + print(complex_paras.groupby('leaky')['edu_span_len'] + .describe().unstack()) print() # for each leaky paragraph, number of paragraphs included in the # smallest RST node that fully covers the leaky paragraph if False: print(leaky_paras[['parent_span_para_len', 'parent_span_para_dist']] - .describe( - percentiles=[.1, .2, .3, .4, .5, .6, .7, .8, .9])) + .describe(percentiles=[.1, .2, .3, .4, .5, .6, .7, .8, .9])) 
print(leaky_paras[(leaky_paras['parent_span_para_dist'] == 1)] .describe()) print(leaky_paras[(leaky_paras['parent_span_para_dist'] == 1) & diff --git a/educe/rst_dt/dep2con.py b/educe/rst_dt/dep2con.py index 4ca2373..65ebeea 100644 --- a/educe/rst_dt/dep2con.py +++ b/educe/rst_dt/dep2con.py @@ -8,12 +8,14 @@ underlying multinuclear relation """ -from collections import namedtuple +from collections import defaultdict, namedtuple import itertools -from .annotation import SimpleRSTTree, Node -from .deptree import RstDtException, NUC_N, NUC_S, NUC_R -from ..internalutil import treenode +from educe.annotation import Span +from educe.internalutil import treenode +from educe.rst_dt.annotation import (NUC_N, NUC_S, NUC_R, Node, RSTTree, + SimpleRSTTree) +from educe.rst_dt.deptree import RstDtException class DummyNuclearityClassifier(object): @@ -29,14 +31,21 @@ class DummyNuclearityClassifier(object): set, mononuclear otherwise. * "most_frequent_by_rel": predicts the most frequent nuclearity for the given relation label in the training set. + * "constant": always predicts a constant label provided by the + user. + + constant : str + The explicit constant as predicted by the "constant" strategy. + This parameter is useful only for the "constant" strategy. TODO ---- complete after `sklearn.dummy.DummyClassifier` """ - def __init__(self, strategy="unamb_else_most_frequent"): + def __init__(self, strategy="unamb_else_most_frequent", constant=None): self.strategy = strategy + self.constant = constant def fit(self, X, y): """Fit the dummy classifier. @@ -52,12 +61,22 @@ def fit(self, X, y): y: array-like, shape = [n_samples] Target nuclearity array for each EDU of each RstDepTree. """ - if self.strategy not in ["unamb_else_most_frequent", - "most_frequent_by_rel"]: + if self.strategy not in ("unamb_else_most_frequent", + "most_frequent_by_rel", + "constant"): raise ValueError("Unknown strategy type.") + if (self.strategy == "constant" and + self.constant not in (NUC_N, NUC_S)): + # ensure that the constant value provided is acceptable + raise ValueError("The constant target value must be " + "{} or {}".format(NUC_N, NUC_S)) + # special processing: ROOT is considered multinuclear - multinuc_lbls = ['ROOT'] + # 2016-12-06 I'm unsure what form "root" should have at this + # point, so all three possible values are currently included + # but we should trim this list down (MM) + multinuc_lbls = ['ROOT', 'root', '---'] if self.strategy == "unamb_else_most_frequent": # FIXME automatically get these from the training set multinuc_lbls.extend(['joint', 'same-unit', 'textual']) @@ -87,12 +106,17 @@ def predict(self, X): """ y = [] for dtree in X: - # NB: we condition multinuclear relations on (i > head) - yi = [(NUC_N if (i > head and rel in self.multinuc_lbls_) - else NUC_S) - for i, (head, rel) - in enumerate(itertools.izip(dtree.heads, dtree.labels))] - y.append(yi) + if self.strategy == "constant": + yi = [self.constant for rel in dtree.labels] + y.append(yi) + else: + # FIXME NUC_R for the root? + # NB: we condition multinuclear relations on (i > head) + yi = [(NUC_N if (i > head and rel in self.multinuc_lbls_) + else NUC_S) + for i, (head, rel) + in enumerate(itertools.izip(dtree.heads, dtree.labels))] + y.append(yi) return y @@ -100,9 +124,9 @@ def predict(self, X): class InsideOutAttachmentRanker(object): """Rank modifiers, from the inside out on either side. 
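A minimal sketch of the new "constant" strategy, assuming `dtrees` is a list of RstDepTree instances:

from educe.rst_dt.annotation import NUC_S
from educe.rst_dt.dep2con import DummyNuclearityClassifier

nuc_clf = DummyNuclearityClassifier(strategy="constant", constant=NUC_S)
nuc_clf.fit(dtrees, [])               # the constant strategy has nothing to estimate
pred_nucs = nuc_clf.predict(dtrees)   # one nuclearity value per EDU, per dtree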
- Given a dependency tree node and its children, return the list - of children but *stably* sorted to fulfill an inside-out - traversal on either side. + Given a dependency tree node and its children, return an order on + the list of children, that should fulfill an inside-out traversal + on either side. Let's back up a little bit for some background on this criterion. We assume that dependency tree nodes can be characterised and @@ -127,45 +151,35 @@ class InsideOutAttachmentRanker(object): - 'lrlrlr': alternating directions, left first, - 'rlrlrl': alternating directions, right first. - TRICKY SORTING! The current implementation of the 'id' strategy was - reached through a bit of trial and error, so you may want to modify - with caution. - - Most of the trickiness in the 'id' strategy is in making this a - *stable* sort, ie. we want to preserve the original order of - the targets as much as possible because this allows us to have - round trip conversions from RST to DT and back. This essentially - means preserving the interleaving of left/right nodes. The basic - logic in the implementation is to traverse our target list as - a series of LEFT or RIGHT slots, filling the slots in an - inside-out order. So for example, if we saw a target list - `l3 r1 r3 l2 l1 r2`, we would treat it as the slots `L R R L L R` - and fill them out as `l1 r1 r2 l2 l3 r3` """ - def __init__(self, strategy='id', prioritize_same_unit=False): + def __init__(self, strategy='id', prioritize_same_unit=False, + order='weak'): if strategy not in ['id', 'lllrrr', 'rrrlll', 'lrlrlr', 'rlrlrl', 'closest-lr', 'closest-rl', 'closest-intra-lr-inter-lr', 'closest-intra-rl-inter-rl', - 'closest-intra-rl-inter-lr']: + 'closest-intra-rl-inter-lr', + 'sdist-edist-lr', 'sdist-edist-rl']: raise ValueError('Unknown transformation strategy ' '{stg}'.format(stg=strategy)) self.strategy = strategy self.prioritize_same_unit = prioritize_same_unit + if order not in ['weak', 'strict']: + raise ValueError("Order must be one of {'weak', 'strict'}") + self.order = order def fit(self, X, y): """Here a no-op.""" return self def predict(self, X): - """Produce a ranking. + """Predict order between modifiers of the same head. - This keeps the alternation of left and right modifiers as it is - in `targets` but possibly re-orders dependents on either side to - guarantee inside-out traversal. + The predicted order should guarantee inside-out traversal on + either side of a head. 
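A companion sketch for the ranker, again over gold dtrees; 'lrlrlr' is used here because it does not need sentence indices, and order='weak' gives consecutive nuclei of the same multinuclear relation a single rank:

from educe.rst_dt.dep2con import InsideOutAttachmentRanker

ranker = InsideOutAttachmentRanker(strategy='lrlrlr', order='weak')
ranker.fit([], [])                    # a no-op, kept for API symmetry
pred_ranks = ranker.predict(dtrees)   # one array of ranks per dtree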
Parameters ---------- @@ -186,6 +200,10 @@ def predict(self, X): """ strategy = self.strategy + if strategy == 'id': # radical oracle: use gold rank as is + # we know the true order, it is stored in dtree.ranks + return [dtree.ranks for dtree in X] + dt_ranks = [] for dtree in X: # for each RstDepTree, the result will be an array of ranks @@ -193,175 +211,299 @@ def predict(self, X): unique_heads = set(dtree.heads[1:]) # exclude head of fake root for head in unique_heads: + rank_idx = 1 # init rank + targets = [i for i, hd in enumerate(dtree.heads) if hd == head] - # what follows should be well-tested code - sorted_nodes = sorted( - [head] + targets, - key=lambda x: dtree.edus[x].span.char_start) - centre = sorted_nodes.index(head) + if self.prioritize_same_unit: + # gobble everything between the head and the rightmost + # "same-unit" + # FIXME weak order: fragments of an n-ary "same-unit" + # should get the same order + same_unit_tgts = [tgt for tgt in targets + if dtree.labels[tgt] == 'same-unit'] + if same_unit_tgts: + # take first all dependents between the head + # and the rightmost same-unit + last_same_unit_tgt = same_unit_tgts[-1] + priority_tgts = [tgt for tgt in targets + if (tgt > head and + tgt <= last_same_unit_tgt)] + for tgt in priority_tgts: + ranks[tgt] = rank_idx + rank_idx += 1 + # remove them from the remaining targets + targets = [tgt for tgt in targets + if tgt not in priority_tgts] + # elements to the left and right of the node respectively # these are stacks (outside ... inside) - left = sorted_nodes[:centre] - right = list(reversed(sorted_nodes[centre+1:])) - - # special strategy: 'id' (we know the true targets) - if strategy == 'id': - result = [left.pop() if (tree in left) else right.pop() - for tree in targets] + left = [i for i in targets if i < head] + right = [i for i in targets if i > head] + right = list(reversed(right)) + + if strategy in ['lllrrr', 'rrrlll']: + # one side then the other + # FIXME weak order + sides = ([left, right] if strategy == 'lllrrr' + else [right, left]) + for side in sides: + for tgt in side: + ranks[tgt] = rank_idx + rank_idx += 1 + + elif strategy in ['lrlrlr', 'rlrlrl']: + # alternating sides + sides = ([left, right] if strategy == 'lrlrlr' + else [right, left]) + while any(sides): + for side in sides: + if not side: + continue + dep_cur = side.pop() + ranks[dep_cur] = rank_idx + # weakly-ordered: consecutive nuclei with + # same label are assumed to be part of a + # multinuclear relation => same rank + lbl_cur = dtree.labels[dep_cur] + nuc_cur = dtree.nucs[dep_cur] + if self.order == 'weak' and nuc_cur == NUC_N: + while (side + and dtree.labels[side[-1]] == lbl_cur + and dtree.nucs[side[-1]] == nuc_cur): + # give same rank + dep_cur = side.pop() + ranks[dep_cur] = rank_idx + # increment rank + rank_idx += 1 + + elif strategy in ['closest-lr', 'closest-rl']: + # take closest dependents first, break ties using the + # side: lr to take left over right, rl to take right + # over left + sides = ([left, right] if strategy == 'closest-lr' + else [right, left]) + while any(sides): + if left and right: + dist_left = abs(left[-1] - head) + dist_right = abs(right[-1] - head) + if dist_left == dist_right: + side = sides[0] + elif dist_left < dist_right: + side = left + else: + side = right + else: # one side is empty + side = left if left else right + # same code as lrlrlr/rlrlrl above ; make into + # helper function? 
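# One possible shape for the helper that the comments above keep asking for
# (illustrative only, not part of this patch): factor out the repeated
# pop-and-rank step, including the weak-order grouping of consecutive
# nuclei that share a label.
def _pop_and_rank(side, ranks, rank_idx, labels, nucs, weak):
    dep_cur = side.pop()
    ranks[dep_cur] = rank_idx
    lbl_cur = labels[dep_cur]
    nuc_cur = nucs[dep_cur]
    if weak and nuc_cur == NUC_N:
        while (side and labels[side[-1]] == lbl_cur
               and nucs[side[-1]] == nuc_cur):
            # same multinuclear relation: give the same rank
            dep_cur = side.pop()
            ranks[dep_cur] = rank_idx
    return rank_idx + 1  # next rank to assign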
+ dep_cur = side.pop() + ranks[dep_cur] = rank_idx + # weakly-ordered: consecutive nuclei with + # same label are assumed to be part of a + # multinuclear relation => same rank + lbl_cur = dtree.labels[dep_cur] + nuc_cur = dtree.nucs[dep_cur] + if self.order == 'weak' and nuc_cur == NUC_N: + while (side + and dtree.labels[side[-1]] == lbl_cur + and dtree.nucs[side[-1]] == nuc_cur): + # give same rank + dep_cur = side.pop() + ranks[dep_cur] = rank_idx + # increment rank + rank_idx += 1 - # strategies that try to guess the order of attachment else: - result = [] - - if self.prioritize_same_unit: - # gobble everything between the head and the rightmost - # "same-unit" - same_unit_tgts = [tgt for tgt in targets - if dtree.labels[tgt] == 'same-unit'] - if same_unit_tgts: - # take first all dependents between the head - # and the rightmost same-unit - last_same_unit_tgt = same_unit_tgts[-1] - priority_tgts = [tgt for tgt in targets - if (tgt > head and - tgt <= last_same_unit_tgt)] - # prepend to the result - result.extend(priority_tgts) - # remove from the remaining targets - targets = [tgt for tgt in targets - if tgt not in priority_tgts] - - if strategy == 'lllrrr': - result.extend(left.pop() if left else right.pop() - for _ in targets) - - elif strategy == 'rrrlll': - result.extend(right.pop() if right else left.pop() - for _ in targets) - - elif strategy == 'lrlrlr': - # reverse lists of left and right modifiers - # these are queues (inside ... outside) - left_io = list(reversed(left)) - right_io = list(reversed(right)) - lrlrlr_gen = itertools.chain.from_iterable( - itertools.izip_longest(left_io, right_io)) - result.extend(x for x in lrlrlr_gen - if x is not None) - - elif strategy == 'rlrlrl': - # reverse lists of left and right modifiers - # these are queues (inside ... 
outside) - left_io = list(reversed(left)) - right_io = list(reversed(right)) - rlrlrl_gen = itertools.chain.from_iterable( - itertools.izip_longest(right_io, left_io)) - result.extend(x for x in rlrlrl_gen - if x is not None) - - elif strategy == 'closest-rl': - # take closest dependents first, take right over - # left to break ties - sort_key = lambda e: (abs(e - head), - 1 if e > head else 2) - result.extend(sorted(targets, key=sort_key)) - - elif strategy == 'closest-lr': - # take closest dependents first, take left over - # right to break ties - sort_key = lambda e: (abs(e - head), - 2 if e > head else 1) - result.extend(sorted(targets, key=sort_key)) - # strategies that depend on intra/inter-sentential info - # NB: the way sentential info is stored is expected to - # change at some point + # NB: the way sentential info is stored is a dirty hack ; + # this should be fixed at some point + if not hasattr(dtree, 'sent_idx'): + raise ValueError(('Strategy {stg} depends on ' + 'sentential information which is ' + 'missing here' + '').format(stg=strategy)) + + # sent_idx for all EDUs that need to be locally + # ranked (+ their head) + # FIXME write a clean imputation procedure + # that is global to all EDUs in the document + loc_edus = sorted(targets + [head]) + sent_idc = [dtree.sent_idx[x] for x in loc_edus + if dtree.sent_idx[x] is not None] + if len(sent_idc) != len(loc_edus): + # missing sent_idx => (pseudo-)imputation ; + # this is a very local, and dirty, workaround + # * left dependents + head + sent_idc_left = [] + sent_idx_cur = min(sent_idc) if sent_idc else 0 + for x in loc_edus: + if x > head: + break + sent_idx_x = dtree.sent_idx[x] + if sent_idx_x is not None: + sent_idx_cur = sent_idx_x + sent_idc_left.append(sent_idx_cur) + # * right dependents + sent_idc_right = [] + sent_idx_cur = max(sent_idc) if sent_idc else 0 + for x in reversed(targets): + if x <= head: + break + sent_idx_x = dtree.sent_idx[x] + if sent_idx_x is not None: + sent_idx_cur = sent_idx_x + sent_idc_right.append(sent_idx_cur) + # * replace sent_idc with the result of the + # pseudo-imputation + sent_idc = (sent_idc_left + + list(reversed(sent_idc_right))) + # build this into a dict + sent_idx_loc = {e: s_idx for e, s_idx + in zip(loc_edus, sent_idc)} + + # intra/inter strategies + if strategy in ['closest-intra-rl-inter-lr', + 'closest-intra-rl-inter-rl', + 'closest-intra-lr-inter-lr', + 'closest-intra-lr-inter-rl']: + # best: closest-intra-rl-inter-lr (2016-07-??) + # current: closest-intra-rl-inter-rl (2016-09-13) + + # intra + left_intra = [tgt for tgt in left + if sent_idx_loc[tgt] == sent_idx_loc[head]] + right_intra = [tgt for tgt in right + if sent_idx_loc[tgt] == sent_idx_loc[head]] + sides = ([right_intra, left_intra] + if strategy in ['closest-intra-rl-inter-lr', + 'closest-intra-rl-inter-rl'] + else [left_intra, right_intra]) + # same code as 'closest-*' above + while any(sides): + if left_intra and right_intra: + dist_left = abs(left_intra[-1] - head) + dist_right = abs(right_intra[-1] - head) + if dist_left == dist_right: + side = sides[0] + elif dist_left < dist_right: + side = left_intra + else: + side = right_intra + else: # one side is empty + side = (left_intra if left_intra + else right_intra) + # same code as lrlrlr/rlrlrl above ; make into + # helper function? 
+ dep_cur = side.pop() + ranks[dep_cur] = rank_idx + # weakly-ordered: consecutive nuclei with + # same label are assumed to be part of a + # multinuclear relation => same rank + lbl_cur = dtree.labels[dep_cur] + nuc_cur = dtree.nucs[dep_cur] + if self.order == 'weak' and nuc_cur == NUC_N: + while (side + and dtree.labels[side[-1]] == lbl_cur + and dtree.nucs[side[-1]] == nuc_cur): + # give same rank + dep_cur = side.pop() + ranks[dep_cur] = rank_idx + # increment rank + rank_idx += 1 + + # inter + left_inter = [tgt for tgt in left + if sent_idx_loc[tgt] != sent_idx_loc[head]] + right_inter = [tgt for tgt in right + if sent_idx_loc[tgt] != sent_idx_loc[head]] + sides = ([right_inter, left_inter] + if strategy in ['closest-intra-lr-inter-rl', + 'closest-intra-rl-inter-rl'] + else [left_inter, right_inter]) + # same code as 'closest-*' above + while any(sides): + if left_inter and right_inter: + dist_left = abs(left_inter[-1] - head) + dist_right = abs(right_inter[-1] - head) + if dist_left == dist_right: + side = sides[0] + elif dist_left < dist_right: + side = left_inter + else: + side = right_inter + else: # one side is empty + side = (left_inter if left_inter + else right_inter) + # same code as lrlrlr/rlrlrl above ; make into + # helper function? + dep_cur = side.pop() + ranks[dep_cur] = rank_idx + # weakly-ordered: consecutive nuclei with + # same label are assumed to be part of a + # multinuclear relation => same rank + lbl_cur = dtree.labels[dep_cur] + nuc_cur = dtree.nucs[dep_cur] + if self.order == 'weak' and nuc_cur == NUC_N: + while (side + and dtree.labels[side[-1]] == lbl_cur + and dtree.nucs[side[-1]] == nuc_cur): + # give same rank + dep_cur = side.pop() + ranks[dep_cur] = rank_idx + # increment rank + rank_idx += 1 + + elif strategy in ['sdist-edist-lr', 'sdist-edist-rl']: + # used 2016-09-13: sdist-edist-rl + # distance in sentences, then in EDUs, then pick + # side to break ties + sides = ([left, right] if strategy == 'sdist-edist-lr' + else [right, left]) + while any(sides): + if left and right: + # distances: in sentences, EDUs + # * next candidate on the left + sdist_left = abs(sent_idx_loc[left[-1]] + - sent_idx_loc[head]) + edist_left = abs(left[-1] - head) + dist_left = (sdist_left, edist_left) + # * next candidate on the right + sdist_right = abs(sent_idx_loc[right[-1]] + - sent_idx_loc[head]) + edist_right = abs(right[-1] - head) + dist_right = (sdist_right, edist_right) + # * compare + if dist_left == dist_right: + side = sides[0] + elif dist_left < dist_right: + side = left + else: + side = right + else: # one side is empty + side = left if left else right + # same code as lrlrlr/rlrlrl above ; make into + # helper function? 
+ dep_cur = side.pop() + ranks[dep_cur] = rank_idx + # weakly-ordered: consecutive nuclei with + # same label are assumed to be part of a + # multinuclear relation => same rank + lbl_cur = dtree.labels[dep_cur] + nuc_cur = dtree.nucs[dep_cur] + if self.order == 'weak' and nuc_cur == NUC_N: + while (side + and dtree.labels[side[-1]] == lbl_cur + and dtree.nucs[side[-1]] == nuc_cur): + # give same rank + dep_cur = side.pop() + ranks[dep_cur] = rank_idx + # increment rank + rank_idx += 1 + else: - if not hasattr(dtree, 'sent_idx'): - raise ValueError(('Strategy {stg} depends on ' - 'sentential information which is' - ' missing here' - '').format(stg=strategy)) - - if strategy == 'closest-intra-rl-inter-lr': - # current best - # take closest dependents first, take right over - # left to break ties - sort_key = lambda e: ( - 1 if dtree.sent_idx[e] == dtree.sent_idx[head] else 2, - abs(e - head), - 1 if ((e > head and - dtree.sent_idx[e] == dtree.sent_idx[head]) or - (e < head and - dtree.sent_idx[e] != dtree.sent_idx[head])) else 2 - ) - result.extend(sorted(targets, key=sort_key)) - - elif strategy == 'closest-intra-rl-inter-rl': # current used - # sent_idx for all EDUs that need to be locally - # ranked (+ their head) - loc_edus = sorted(targets + [head]) - sent_idc = [dtree.sent_idx[x] for x in loc_edus - if dtree.sent_idx[x] is not None] - if len(sent_idc) != len(loc_edus): - # missing sent_idx => (pseudo-)imputation ; - # this is a very local, and dirty, workaround - # * left dependents + head - sent_idc_left = [] - sent_idx_cur = min(sent_idc) if sent_idc else 0 - for x in loc_edus: - if x > head: - break - sent_idx_x = dtree.sent_idx[x] - if sent_idx_x is not None: - sent_idx_cur = sent_idx_x - sent_idc_left.append(sent_idx_cur) - # * right dependents - sent_idc_right = [] - sent_idx_cur = max(sent_idc) if sent_idc else 0 - for x in reversed(targets): - if x <= head: - break - sent_idx_x = dtree.sent_idx[x] - if sent_idx_x is not None: - sent_idx_cur = sent_idx_x - sent_idc_right.append(sent_idx_cur) - # * replace sent_idc with the result of the - # pseudo-imputation - sent_idc = (sent_idc_left + - list(reversed(sent_idc_right))) - # build this into a dict - sent_idx_loc = {e: s_idx for e, s_idx - in zip(loc_edus, sent_idc)} - - # take closest dependents first, break ties by - # choosing right first, then left - sort_key = lambda e: ( - abs(sent_idx_loc[e] - sent_idx_loc[head]), - abs(e - head), - 1 if e > head else 2 - ) - result.extend(sorted(targets, key=sort_key)) - - elif strategy == 'closest-intra-lr-inter-lr': - # take closest dependents first, take left over - # right to break ties - sort_key = lambda e: ( - 1 if dtree.sent_idx[e] == dtree.sent_idx[head] else 2, - abs(e - head), - 2 if e > head else 1 - ) - result.extend(sorted(targets, key=sort_key)) - - else: - raise RstDtException('Unknown transformation strategy' - ' {stg}'.format(stg=strategy)) - - # update array of ranks for this deptree - # ranks are 1-based - for i, tgt in enumerate(result, start=1): - ranks[tgt] = i + raise RstDtException('Unknown transformation strategy' + ' {stg}'.format(stg=strategy)) dt_ranks.append(ranks) return dt_ranks @@ -584,3 +726,144 @@ class TreeParts(namedtuple("TreeParts_", "edu edu_span span rel kids")): """ pass # pylint: enable=R0903, W0232 + + +def deptree_to_rst_tree(dtree): + """Create an RSTTree from an RstDepTree. + + Parameters + ---------- + dtree: RstDepTree + RST dependency tree, i.e. an ordered dtree. 
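A minimal round-trip sketch for this function, assuming `ctree` is a gold RSTTree; note that the code below currently raises when the fake root has more than one real dependent:

from educe.rst_dt.deptree import RstDepTree
from educe.rst_dt.dep2con import deptree_to_rst_tree

dtree = RstDepTree.from_rst_tree(ctree, nary_enc='chain')
ctree2 = deptree_to_rst_tree(dtree)   # back to an RSTTree constituency tree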
+ + Returns + ------- + ctree: RSTTree + RST constituency tree that corresponds to the dtree. + """ + heads = dtree.heads + ranks = dtree.ranks + origin = dtree.origin + + # gov -> (rank -> [deps]) + ranked_deps = defaultdict(lambda: defaultdict(list)) + for dep, (gov, rnk) in enumerate(zip(heads[1:], ranks[1:]), start=1): + ranked_deps[gov][rnk].append(dep) + + # store pointers to substructures as they are built + subtrees = [None for x in dtree.edus] + + # compute height of each governor in the dtree + heights = [0 for x in dtree.edus] + while True: + old_heights = tuple(heights) + for i, hd in enumerate(dtree.heads[1:], start=1): + heights[hd] = max(heights[hd], heights[i] + 1) + if tuple(heights) == old_heights: + # fixpoint reached + break + # group nodes by their height in the dtree + govs_by_height = defaultdict(list) + for i, height in enumerate(heights): + govs_by_height[height].append(i) + + # bottom-up traversal of the dtree: create sub-ctrees + # * create leaves of the RST ctree: initialize them with the + # label and nuclearity from the dtree + for i in range(1, len(dtree.edus)): + node = Node(dtree.nucs[i], (i, i), dtree.edus[i].span, + dtree.labels[i], context=None) # TODO context? + children = [dtree.edus[i]] # WIP + subtrees[i] = RSTTree(node, children, origin=origin) + + # * create internal nodes: for each governor, create one projection + # per rank of dependents ; each time a projection node is created, + # we use the set of dependencies to overwrite the nuc and label of + # its children + for height in range(1, max(heights)): # leave fake root out, see below + nodes = govs_by_height[height] + for gov in nodes: + # max_rnk = max(ranked_deps[gov].keys()) + for rnk, deps in sorted(ranked_deps[gov].items()): + # overwrite the nuc and lbl of the head node, using the + # dependencies of this rank + dep_nucs = [dtree.nucs[x] for x in deps] + dep_lbls = [dtree.labels[x] for x in deps] + if all(x == NUC_N for x in dep_nucs): + # all nuclei must have the same label, to denote + # a unique multinuclear relation + assert len(set(dep_lbls)) == 1 + gov_lbl = dep_lbls[0] + elif all(x == NUC_S for x in dep_nucs): + gov_lbl = 'span' + else: + raise ValueError('Deps have different nuclearities') + gov_node = subtrees[gov].label() + gov_node.nuclearity = NUC_N + gov_node.rel = gov_lbl + # create one projection node for the head + the dependencies + # of this rank + proj_lbl = dtree.labels[gov] + proj_nuc = dtree.nucs[gov] + proj_children = [subtrees[x] for x in sorted([gov] + deps)] + proj_edu_span = (proj_children[0].label().edu_span[0], + proj_children[-1].label().edu_span[1]) + proj_txt_span = Span(proj_children[0].label().span.char_start, + proj_children[-1].label().span.char_end) + proj_node = Node(proj_nuc, proj_edu_span, proj_txt_span, + proj_lbl, context=None) # TODO context? 
+ subtrees[gov] = RSTTree(proj_node, proj_children, + origin=origin) + # create top node and whole tree + # this is where we handle the fake root + gov = 0 + proj_lbl = '---' # 2016-12-02: switch from "ROOT" to "---" so that + # _pred and _true have the same labels for their root nodes + proj_nuc = NUC_R + if (ranked_deps[gov].keys() == [1] + and len(ranked_deps[gov][1]) == 1): + # unique real root => use its projection as the root of the ctree + unique_real_root = ranked_deps[gov][1][0] + # proj = subtrees[unique_real_root].label() + proj_node.nuclearity = proj_nuc + proj_node.rel = proj_lbl + subtrees[0] = subtrees[unique_real_root] + else: + # > 1 real root: create projections until we span all + # 2016-09-14 disable support for >1 real root + raise ValueError("Fragile: RSTTree from dtree with >1 real root") + # + # max_rnk = max(ranked_deps[gov].keys()) + for rnk, deps in sorted(ranked_deps[gov].items()): + # overwrite the nuc and lbl of the head node, using the + # dependencies of this rank + dep_nucs = [dtree.nucs[x] for x in deps] + dep_lbls = [dtree.labels[x] for x in deps] + if all(x == NUC_N for x in dep_nucs): + # all nuclei must have the same label, to denote + # a unique multinuclear relation + assert len(set(dep_lbls)) == 1 + gov_lbl = dep_lbls[0] + elif all(x == NUC_S for x in dep_nucs): + gov_lbl = 'span' + else: + raise ValueError('Deps have different nuclearities') + gov_node = subtrees[gov].label() + gov_node.nuclearity = NUC_N + gov_node.rel = gov_lbl + # create one projection node for the head + the dependencies + # of this rank + proj_lbl = dtree.labels[gov] + proj_nuc = dtree.nucs[gov] + proj_children = [subtrees[x] for x in sorted([gov] + deps)] + proj_edu_span = (proj_children[0].label().edu_span[0], + proj_children[-1].label().edu_span[1]) + proj_txt_span = Span(proj_children[0].label().span.char_start, + proj_children[-1].label().span.char_end) + proj_node = Node(proj_nuc, proj_edu_span, proj_txt_span, + proj_lbl, context=None) # TODO context? + subtrees[gov] = RSTTree(proj_node, proj_children, + origin=origin) + # final RST ctree + rst_tree = subtrees[0] + return rst_tree diff --git a/educe/rst_dt/deptree.py b/educe/rst_dt/deptree.py index 4d589c4..ab1347e 100644 --- a/educe/rst_dt/deptree.py +++ b/educe/rst_dt/deptree.py @@ -12,15 +12,10 @@ import numpy as np -from .annotation import EDU +from .annotation import EDU, _binarize, NUC_N, NUC_S # , NUC_R from ..internalutil import treenode -NUC_N = "Nucleus" -NUC_S = "Satellite" -NUC_R = "Root" - - class RstDtException(Exception): """ Exceptions related to conversion between RST and DT trees. @@ -41,6 +36,46 @@ def __init__(self, msg): DEFAULT_RANK = 0 +# helper function for conversion from binary to nary relations +def binary_to_nary(nary_enc, pairs): + """Retrieve nary relations from a set of binary relations. + + Parameters + ---------- + nary_enc: one of {"chain", "tree"} + Encoding from n-ary to binary relations. + pairs: iterable of pairs of identifier (ex: integer, string...) + Binary relations. + + Return + ------ + nary_rels: list of tuples of identifiers + Nary relations. 
+ """ + nary_rels = [] + open_ends = [] # companion to nary_rels: open end + for gov_idx, dep_idx in pairs: + try: + # search for an existing fragmented EDU this same-unit + # could belong to + open_frag = open_ends.index(gov_idx) + except ValueError: + # start a new fragmented EDU + nary_rels.append([gov_idx, dep_idx]) + if nary_enc == 'chain': + open_ends.append(dep_idx) + else: # 'tree' + open_ends.append(gov_idx) + else: + # append dep_idx to an existing fragmented EDU + nary_rels[open_frag].append(dep_idx) + # NB: if "tree", no need to update the open end + if nary_enc == 'chain': + open_ends[open_frag] = dep_idx + nary_rels = [tuple(x) for x in nary_rels] + return nary_rels + + class RstDepTree(object): """RST dependency tree @@ -50,9 +85,16 @@ class RstDepTree(object): List of the EDUs of this document. origin : Document?, optional TODO - """ + nary_enc : one of {'chain', 'tree'}, optional + Type of encoding used for n-ary relations: 'chain' or 'tree'. + This determines for example how fragmented EDUs are resolved. + """ - def __init__(self, edus=[], origin=None): + def __init__(self, edus=[], origin=None, nary_enc='chain'): + # WIP 2016-07-20 nary_enc to resolve fragmented EDUs + if nary_enc not in ['chain', 'tree']: + raise ValueError("nary_enc must be in {'tree', 'chain'}") + self.nary_enc = nary_enc # FIXME find a clean way to avoid generating a new left padding EDU # here _lpad = EDU.left_padding() @@ -177,17 +219,31 @@ def add_dependencies(self, gov_num, dep_nums, labels=None, nucs=None, # common rank self.ranks[_idx_dep] = rank - def get_dependencies(self): + def get_dependencies(self, lbl_type='rel'): """Get the list of dependencies in this dependency tree. Each dependency is a 3-uple (gov, dep, label), gov and dep being EDUs. + + Parameters + ---------- + lbl_type: one of {'rel', 'rel+nuc'} (TODO 'rel+nuc+rnk'?) + Type of the labels. """ + if lbl_type not in ['rel', 'rel+nuc']: + raise ValueError("lbl_type needs to be one of {'rel', 'rel+nuc'}") + edus = self.edus deps = self.edus[1:] gov_idxs = self.heads[1:] - labels = self.labels[1:] + if lbl_type == 'rel': + labels = self.labels[1:] + elif lbl_type == 'rel+nuc': + labels = list(zip(self.labels[1:], + ['N' + nuc[0] for nuc in self.nucs[1:]])) + else: + raise NotImplementedError("WIP") result = [(edus[gov_idx], dep, lbl) for gov_idx, dep, lbl @@ -217,6 +273,26 @@ def deps(self, gov_idx): sorted_deps = [i for rk, i in ranked_deps] return sorted_deps + def fragmented_edus(self): + """Get the fragmented EDUs in this RST tree. + + Fragmented EDUs are made of two or more EDUs linked by + "same-unit" relations. + + Returns + ------- + frag_edus: list of tuple of int + Each fragmented EDU is given as a tuple of the indices of + the fragments. 
+ """ + nary_enc = self.nary_enc + su_deps = [(gov_idx, dep_idx) for dep_idx, (gov_idx, lbl) + in enumerate(zip(self.heads[1:], self.labels[1:]), + start=1) + if lbl.lower() == 'same-unit'] + frag_edus = binary_to_nary(nary_enc, su_deps) + return frag_edus + def real_roots_idx(self): """Get the list of the indices of the real roots""" return self.deps(_ROOT_HEAD) @@ -258,7 +334,9 @@ def spans(self): def from_simple_rst_tree(cls, rtree): """Converts a ̀SimpleRSTTree` to an `RstDepTree`""" edus = sorted(rtree.leaves(), key=lambda x: x.span.char_start) - dtree = cls(edus) + # building a SimpleRSTTree requires to binarize the original + # RSTTree first, so 'chain' is the only possibility + dtree = cls(edus, nary_enc='chain') def walk(tree): """ @@ -303,10 +381,19 @@ def walk(tree): return dtree @classmethod - def from_rst_tree(cls, rtree): - """Converts an ̀RSTTree` to an `RstDepTree`""" + def from_rst_tree(cls, rtree, nary_enc='tree'): + """Converts an ̀RSTTree` to an `RstDepTree`. + + Parameters + ---------- + nary_enc : one of {'chain', 'tree'} + If 'chain', the given RSTTree is binarized first. + """ edus = sorted(rtree.leaves(), key=lambda x: x.span.char_start) - dtree = cls(edus) + # if 'chain', binarize the tree first + if nary_enc == 'chain': + rtree = _binarize(rtree) + dtree = cls(edus, nary_enc=nary_enc) def walk(tree): """ diff --git a/educe/rst_dt/document_plus.py b/educe/rst_dt/document_plus.py index d6b606f..5ebab28 100644 --- a/educe/rst_dt/document_plus.py +++ b/educe/rst_dt/document_plus.py @@ -399,24 +399,98 @@ def align_with_trees(self, strict=False): return self - def all_edu_pairs(self): - """Generate all EDU pairs of a document""" + def all_edu_pairs(self, ordered=True): + """Generate all EDU pairs of a document. + + Parameters + ---------- + ordered: boolean, defaults to True + If True, generate all ordered pairs of EDUs, otherwise + (half as many) unordered pairs. + + Returns + ------- + all_pairs: [(EDU, EDU)] + All pairs of EDUs in this document. + """ edus = self.edus - all_pairs = [epair for epair in itertools.product(edus, edus[1:]) - if epair[0] != epair[1]] + if ordered: + all_pairs = [epair for epair in itertools.product(edus, edus[1:]) + if epair[0] != epair[1]] + else: + all_pairs = list(itertools.combinations(edus, 2)) return all_pairs - def relations(self, edu_pairs): - """Get the relation that holds in each of the edu_pairs""" + def relations(self, du_pairs, lbl_type='rel', ordered=True): + """Get the relation that holds in each of the DU pairs. + + As of 2016-09-30, this function has a unique caller: + doc_vectorizer.DocumentLabelExtractor._extract_labels() . + + Parameters + ---------- + du_pairs : [(DU, DU)] + List of DU pairs. + + lbl_type : one of {'rel', 'rel+nuc'} + Type of label. + + ordered : boolean, defaults to True + If True, du_pairs are considered ordered, otherwise the + label of either (edu1, edu2) or (edu2, edu1) is returned (if + not None). + + Returns + ------- + erels : :obj:`list` of :obj:`str` + Relation for each pair of DUs. 
+ """ if not self.deptree: - return [None for epair in edu_pairs] + return [None for epair in du_pairs] - rels = {(src, tgt): rel - for src, tgt, rel in self.deptree.get_dependencies()} - erels = [rels.get(epair, 'UNRELATED') - for epair in edu_pairs] + if ordered: + rels = {(src, tgt): rel + for src, tgt, rel + in self.deptree.get_dependencies(lbl_type=lbl_type)} + else: + # on unordered pairs, if dep < gov we need to invert the + # nuclearity so that it encodes direction of attachment: + # NS/NN for right attachment, SN for left + rels = dict() + for src, tgt, rel in self.deptree.get_dependencies( + lbl_type=lbl_type): + if src.num < tgt.num: + # right attachment + u_pair = (src, tgt) + new_rel = rel + else: + # left attachment + u_pair = (tgt, src) + new_rel = ((rel[0], rel[1][1] + rel[1][0]) + if lbl_type == 'nuc+rel' + else rel) + rels[u_pair] = new_rel + + erels = [rels.get(epair, 'UNRELATED') for epair in du_pairs] return erels + def same_unit_candidates(self): + """Generate all EDU pairs that could be a same-unit. + + We use the following filters: + * right-attachment: i < j, + * same sentence: edu2sent[i] == edu2sent[j], + * len > 1: i + 1 < j + """ + edus = self.edus + edu2sent = self.edu2sent + # combinations() generates right-attachment candidates only + su_cands = [(edus[i], edus[j]) for i, j + in itertools.combinations(range(0, len(edus)), 2) + if (i + 1 < j + and edu2sent[i] == edu2sent[j])] + return su_cands + def set_syn_ctrees(self, tkd_trees, lex_heads=None): """Set syntactic constituency trees for this document. diff --git a/educe/rst_dt/rst_relations.py b/educe/rst_dt/rst_relations.py new file mode 100644 index 0000000..5403490 --- /dev/null +++ b/educe/rst_dt/rst_relations.py @@ -0,0 +1,682 @@ +"""Structured inventory of relations used in the RST-DT. + +This module provides a structured view of the relation labels used in the +RST-DT, using information from the reference manual [rst-dt-manual]_ +and the initial instructions for annotators [rst-dt-instru]_. + +References +---------- +.. [rst-dt-manual] Carlson, L., & Marcu, D. (2001). Discourse tagging reference manual. ISI Technical Report ISI-TR-545, 54, 56. +.. [rst-dt-instru] Marcu, D. (1999). Instructions for manually annotating the discourse structures of texts. Unpublished manuscript, USC/ISI. +""" + +from __future__ import absolute_import, print_function + +import os + +import nltk.tree + +from educe.rst_dt.annotation import SimpleRSTTree +from educe.rst_dt.corpus import Reader +from educe.rst_dt.deptree import RstDepTree +from educe.rst_dt.pseudo_relations import rewrite_pseudo_rels + + +# Inventory of classes of rhetorical relations, from subsection 4.1 of +# the [rst-dt-manual]_. +# It maps 18 classes to 54 "representative members" of the 78 (53 mononuclear, +# 25 multinuclear) used in the RST-DT ; 2 multinuclear relations are in fact +# pseudo-relations: "Same-Unit" and "TextualOrganization", each has its own +# class. +# For completeness, I have added 2 more classes: "span" for c-trees, "root" +# for d-trees. 
+# Note: +# * attribution-n is a proper relation related to, but distinct from, +# "attribution", the "-n" stands for "attribution-negative" rather than +# "attribution from the nucleus" +RELATION_CLASSES = { + "attribution": ["attribution", "attribution-n"], + "background": ["background", "circumstance"], + "cause": ["cause", "consequence"], # "result" grouped with "cause" below + "comparison": ["comparison", "preference", "analogy", "proportion"], + "condition": ["condition", "hypothetical", "contingency", "otherwise"], + "contrast": ["contrast", "concession"], # "antithesis" with "contrast" + "elaboration": ["elaboration-additional", "elaboration-general-specific", + "elaboration-part-whole", "elaboration-process-step", + "elaboration-object-attribute", "elaboration-set-member", + "example", "definition"], + "enablement": ["purpose", "enablement"], + "evaluation": ["evaluation", "interpretation", "conclusion", "comment"], + "explanation": ["evidence", "explanation-argumentative", "reason"], + "joint": ["list", "disjunction"], + "manner-means": ["manner", "means"], + "topic-comment": ["problem-solution", "question-answer", + "statement-response", "topic-comment", "comment-topic", + "rhetorical-question"], + "summary": ["summary", "restatement"], + "temporal": ["temporal-before", "temporal-after", "temporal-same-time", + "sequence", "inverted-sequence"], + "topic-change": ["topic-shift", "topic-drift"], + # the 25 multinuclear relations include 2 pseudo-relations + "same-unit": ["same-unit"], + "textual": ["textualorganization"], + # add label "span" for completeness, for c-trees + "span": ["span"], + # add label "root" for completeness, for d-trees + "root": ["root"], +} + +# groups of relation labels that differ only in the respective nuclearity +# of their arguments ; +# groups of relations are triples of: +# (Mononuclear-satellite, Mononuclear-nucleus, Multinuclear) +# where 0 to 2 slots can be empty (None) +RELATION_REPRESENTATIVES = { + "analogy": ("analogy", None, "Analogy"), + # "antithesis": ("antithesis", None, "Contrast") # see "contrast" + "attribution": ("attribution", None, None), + "attribution-n": ("attribution-n", None, None), # negative attribution + "background": ("background", None, None), + "cause": ("result", "cause", "Cause-Result"), # "result" moved here + "circumstance": ("circumstance", None, None), + "comparison": ("comparison", None, "Comparison"), + "comment": ("comment", None, None), + "comment-topic": (None, None, "Comment-Topic"), # linear order of args of "topic-comment" reversed + "concession": ("concession", None, None), + "conclusion": ("conclusion", None, "Conclusion"), + "condition": ("condition", None, None), + "consequence": ("consequence-s", "consequence-n", "Consequence"), + "contingency": ("contingency", None, None), + "contrast": ("antithesis", None, "Contrast"), + "definition": ("definition", None, None), + "disjunction": (None, None, "Disjunction"), + "elaboration-additional": ("elaboration-additional", None, None), + "elaboration-set-member": ("elaboration-set-member", None, None), + "elaboration-part-whole": ("elaboration-part-whole", None, None), + "elaboration-process-step": ("elaboration-process-step", None, None), + "elaboration-object-attribute": ("elaboration-object-attribute", None, None), + "elaboration-general-specific": ("elaboration-general-specific", None, None), + "enablement": ("enablement", None, None), + "evaluation": ("evaluation-s", "evaluation-n", "Evaluation"), + "evidence": ("evidence", None, None), + "example": 
("example", None, None), + "explanation-argumentative": ("explanation-argumentative", None, None), + "hypothetical": ("hypothetical", None, None), + "interpretation": ("interpretation-s", "interpretation-n", "Interpretation"), + "inverted-sequence": (None, None, "Inverted-Sequence"), + "list": (None, None, "List"), + "manner": ("manner", None, None), + "means": ("means", None, None), + "otherwise": ("otherwise", None, "Otherwise"), + "preference": ("preference", None, None), + "problem-solution": ("problem-solution-s", "problem-solution-n", "Problem-Solution"), + "proportion": (None, None, "Proportion"), + "purpose": ("purpose", None, None), + "question-answer": ("question-answer-s", "question-answer-n", "Question-Answer"), + "reason": ("reason", None, "Reason"), + "restatement": ("restatement", None, None), + # "result": ("cause", "result", "Cause-Result") # see "cause" + "rhetorical-question": ("rhetorical-question", None, None), + "same-unit": (None, None, "Same-Unit"), + "sequence": (None, None, "Sequence"), + "statement-response": ("statement-response-s", "statement-response-n", "Statement-Response"), + "summary": ("summary-s", "summary-n", None), + "temporal-before": (None, "temporal-before", None), + "temporal-same-time": ("temporal-same-time", "temporal-same-time", "Temporal-Same-Time"), + "temporal-after": (None, "temporal-after", None), + "textualorganization": (None, None, "TextualOrganization"), + "topic-comment": (None, None, "Topic-Comment"), + "topic-drift": ("topic-drift", None, "Topic-Drift"), + "topic-shift": ("topic-shift", None, "Topic-Shift"), + # for completeness (maybe useless) + "span": "span", +} + +# other, less populated dimensions or similarities can link relations: +# * "antithesis" differs from "concession": the latter is characterized by a +# violated expectation, +# * "attribution-n" is an "attribution" with a negation (negations like "not" +# but also semantically negative verbs like "deny") in the source +# (satellite) +# * "background" is weaker than "circumstance": often, events in "background" +# happen at distinct times whereas events in "circumstance" are somewhat +# co-temporal, +# * "cause-result" and "consequence" ("Consequence" ~ "Cause-Result", +# "consequence-n" ~ "result" and "consequence-s" ~ "cause") are similar, +# the former are for when the causality is perceived as more direct while +# the latter are for more indirect causal relation. 
+# * "comment" could be confused with "evaluation" and "interpretation" +# * "Comment-Topic" and "Topic-Comment" are the same relation but the linear +# order of their arguments is reversed (Comment then Topic or the other way +# around) +# * "comparison" could be confused with "contrast", but the latter typically +# contains a contrastive discourse cue (ex: but, however, while) while the +# former does not +# * "consequence-n" is similar to "result", "consequence-s" +# * the satellite of "elaboration-process-step" is usually realized as a +# multinuclear "Sequence" relation +# * the satellite of "elaboration-set-member" can be a multinuclear "List" +# relation where each member elaborates on part of the nucleus +# * "example" should be chosen rather than "elaboration-set-member" if not +# the other members of the set are not known or specified +# * "explanation-argumentative" differs from "evidence" in that the writer +# has no intention to convince the reader of a point in the former, and +# it differs from "reason" because the latter involves the will or +# intentions of the agent (hence the agent must be animate) +# * "hypothetical" presents a more abstract scenario than "condition" +# * "inverted-sequence" is "sequence" with elements in reverse chronological +# order +# * "List" is for situations where "comparison", "contrast", or other +# multinuclear relations +# * "manner" is less "goal-oriented" than "means", describes more the style +# of an action +# * "preference" compares two situations/acts/events/... and assigns a clear +# preference for one of those +# * "purpose" differs from "result" in that the satellite is only putative +# (yet to be achieved) in the former, factual (achieved) in the latter; +# can be confused with "elaboration-object-attribute-e" but the latter +# can modify a noun phrase as a relative +# * "restatement" just reiterates the info with slightly different wording, +# as opposed to e.g. interpretation +# * "temporal-before" is for mononuclear relations, usually the satellite is +# realized as a subordinate clause that follows the nucleus ; +# if the second (in the linear order) event happens before the first but +# the relation is multinuclear, use "Inverted-Sequence". +# * "temporal-after" is for mononuclear relations (see "temporal-before") ; +# for multinuclear relations with e1 < e2 use "Sequence". +# * "topic-shift" differs from "topic-drift": in the latter, the same elements +# are in focus whereas it is not the case in the former + + +# embedded relations explicitly present in the annotation guide +EMBEDDED_RELATIONS = [ + "elaboration-additional-e", + "elaboration-object-attribute-e", + "elaboration-set-member-e", + "interpretation-s-e", + "manner-e", +] + + +# mapping from fine-grained to coarse-grained relation labels +FINE_TO_COARSE = { + "analogy": "comparison", + "analogy-e": "comparison", + "antithesis": "contrast", + "antithesis-e": "contrast", + "attribution": "attribution", + "attribution-e": "attribution", + "attribution-n": "attribution", # stands for "attribution-negative" !! + "background": "background", + "background-e": "background", + "cause": "cause", # missing from prev version of mapping (!?) + "cause-e": "cause", # origin? 
corpus: NO, Joty's map: NO + "cause-result": "cause", + "circumstance": "background", + "circumstance-e": "background", + "comment": "evaluation", + "comment-e": "evaluation", + "comment-topic": "topic-comment", + "comparison": "comparison", + "comparison-e": "comparison", + "concession": "contrast", + "concession-e": "contrast", + "conclusion": "evaluation", + "condition": "condition", + "condition-e": "condition", + "consequence": "cause", + "consequence-n": "cause", + "consequence-n-e": "cause", + "consequence-s": "cause", + "consequence-s-e": "cause", + "contingency": "condition", + "contrast": "contrast", + "definition": "elaboration", + "definition-e": "elaboration", + "disjunction": "joint", + "elaboration-additional": "elaboration", + "elaboration-additional-e": "elaboration", + "elaboration-e": "elaboration", # origin? corpus: NO, Joty's map: NO + "elaboration-general-specific": "elaboration", + "elaboration-general-specific-e": "elaboration", + "elaboration-object-attribute": "elaboration", + "elaboration-object-attribute-e": "elaboration", + "elaboration-part-whole": "elaboration", + "elaboration-part-whole-e": "elaboration", + "elaboration-process-step": "elaboration", + "elaboration-process-step-e": "elaboration", + "elaboration-set-member": "elaboration", + "elaboration-set-member-e": "elaboration", + "enablement": "enablement", + "enablement-e": "enablement", + "evaluation": "evaluation", + "evaluation-n": "evaluation", + "evaluation-s": "evaluation", + "evaluation-s-e": "evaluation", + "evidence": "explanation", + "evidence-e": "explanation", + "example": "elaboration", + "example-e": "elaboration", + "explanation-argumentative": "explanation", + "explanation-argumentative-e": "explanation", + "hypothetical": "condition", + "interpretation": "evaluation", + "interpretation-n": "evaluation", + "interpretation-s": "evaluation", + "interpretation-s-e": "evaluation", + "inverted-sequence": "temporal", + "list": "joint", + "manner": "manner-means", + "manner-e": "manner-means", + "means": "manner-means", + "means-e": "manner-means", + "otherwise": "condition", + "preference": "comparison", + "preference-e": "comparison", + "problem-solution": "topic-comment", + "problem-solution-n": "topic-comment", + "problem-solution-s": "topic-comment", + "proportion": "comparison", + "purpose": "enablement", + "purpose-e": "enablement", + "question-answer": "topic-comment", + "question-answer-n": "topic-comment", + "question-answer-s": "topic-comment", + "reason": "explanation", + "reason-e": "explanation", + "restatement": "summary", + "restatement-e": "summary", + "result": "cause", + "result-e": "cause", + "rhetorical-question": "topic-comment", + "same-unit": "same-unit", # pseudo-rel + "sequence": "temporal", + "statement-response": "topic-comment", + "statement-response-n": "topic-comment", + "statement-response-s": "topic-comment", + "summary-n": "summary", + "summary-s": "summary", + "temporal-after": "temporal", + "temporal-after-e": "temporal", + "temporal-before": "temporal", + "temporal-before-e": "temporal", + "temporal-same-time": "temporal", + "temporal-same-time-e": "temporal", + "textualorganization": "textual", # pseudo-rel + "topic-comment": "topic-comment", + "topic-comment-n": "topic-comment", # origin? corpus: NO, Joty's map: NO + "topic-comment-s": "topic-comment", # origin? 
corpus: NO, Joty's map: NO + "topic-drift": "topic-change", + "topic-shift": "topic-change", +} + +# TODO test that we have the same mapping here and in Joty's file +joty_map = dict() +with open('/home/mmorey/melodi/joty/parsing_eval_metrics/RelationClasses.txt') as f: + for line in f: + line = line.strip() + if not line: + continue + fields = line.split(':') + if len(fields) != 2: + print(line) + raise ValueError('gni') + coarse_lbl = fields[0].strip().lower() + fine_lbls = [x.strip() for x in fields[1].split(', ')] + for fine_lbl in fine_lbls: + joty_map[fine_lbl] = coarse_lbl + +print(sorted(set(FINE_TO_COARSE.items()) - set(joty_map.items()))) +print(sorted(set(joty_map.items()) - set(FINE_TO_COARSE.items()))) +# assert set(FINE_TO_COARSE.items()) == set(relmap.items()) +# FIXME: comparison between our mapping and Joty's reveals 2 differences: +# * 1 major: "comment: evaluation" (ours) vs "comment: topic-comment" (Joty) +# * 1 minor: "textualorganization: textual" (ours) vs +# "textualorganization: textualorganization" (joty) + +# Examples of TextualOrganization blocks: +# - dateline: wsj_1105: "CHICAGO -", wsj_1377: "SMYRNA, Ga. --" +# - byline: (lots of examples) + +# add "root" label for dependency trees +FINE_TO_COARSE["root"] = "root" +# add "span" label (?) for constituency trees +FINE_TO_COARSE["span"] = "span" + +RST_RELS_FINE = sorted(FINE_TO_COARSE.keys()) +RST_RELS_COARSE = sorted(set(FINE_TO_COARSE.values())) + + +# WIP +# relative to the educe docs directory +# was: DATA_DIR = '/home/muller/Ressources/' +DATA_DIR = os.path.join( + os.path.dirname(__file__), + '..', '..', + 'data', # alt: '..', '..', 'corpora' +) +RST_DIR = os.path.join(DATA_DIR, 'rst_discourse_treebank', 'data') +RST_CORPUS = { + 'train': os.path.join(RST_DIR, 'RSTtrees-WSJ-main-1.0', 'TRAINING'), + 'test': os.path.join(RST_DIR, 'RSTtrees-WSJ-main-1.0', 'TEST'), + 'double': os.path.join(RST_DIR, 'RSTtrees-WSJ-double-1.0'), +} + +rst_corpus_dir = RST_CORPUS['train'] +rst_reader = Reader(rst_corpus_dir) +rst_corpus = rst_reader.slurp(verbose=True) +ctrees = [doc for doc_key, doc in sorted(rst_corpus.items())] + +for doc_key, ctree in sorted(rst_corpus.items()): + rewrite_pseudo_rels(doc_key, ctree) + +raise ValueError('WIP TextualOrganization and Topic-Shift') + +# "chain" transform from ctree to dtree (via SimpleRSTTree) +dtrees = [RstDepTree.from_simple_rst_tree(SimpleRSTTree.from_rst_tree(doc)) + for doc_key, doc in sorted(rst_corpus.items())] + + +# get dependencies +def get_dependencies(dtree): + """Get dependency triplets from a dependency tree""" + return [(gov_idx, dep_idx, lbl) + for dep_idx, (gov_idx, lbl) in enumerate( + zip(dtree.heads[1:], dtree.labels[1:]), + start=1)] + + +# examine Same-Unit relations: search for governors on the right +def check_su_right_gov(dtree): + """TODO""" + all_deps = get_dependencies(dtree) + su_roots = set(gov_idx + for gov_idx, dep_idx, lbl + in all_deps + if lbl == 'Same-Unit') + su_right_govs = [(gov_idx, dep_idx, lbl) + for gov_idx, dep_idx, lbl + in all_deps + if (dep_idx in su_roots and + gov_idx > dep_idx)] + if su_roots: + print('W: {}\t{} out of {} Same-Unit roots have a right governor'.format( + dtree.origin.doc, len(su_right_govs), len(su_roots))) + for gov_idx, dep_idx, lbl in su_right_govs: + print('\t{}\t{}\t{}'.format( + gov_idx, dep_idx, dtree.labels[dep_idx])) + return su_right_govs + + +# get same-unit pairs, from dtree or ctree +def same_units_deps_from_dtree(dtree): + """Get same unit dependencies from a dependency tree""" + return [(gov_idx, 
dep_idx) + for dep_idx, gov_idx, lbl in get_dependencies(dtree) + if lbl == 'Same-Unit'] + + +def same_units_deps_from_ctree(ctree): + """Get same unit dependencies from a constituency tree""" + su_pairs = [] # result + + tree_posits = ctree.treepositions() + for tpos in tree_posits: + node = ctree[tpos] + if not isinstance(node, nltk.tree.Tree): # skip leaf nodes + continue + same_units = [(i, x) for i, x in enumerate(node) + if (isinstance(x, nltk.tree.Tree) and + x.label().rel == 'Same-Unit')] + lmost_leaves = [] + for i, x in same_units: + # compare the leftmost leaf of the leftmost nucleus + # with the recursive leftmost nucleus + # * leftmost leaf of the leftmost nucleus + if ((len(x) == 1 and + not isinstance(x[0], nltk.tree.Tree))): + lmost_leaf = x[0] + else: + lmost_nuc = [y for y in x + if (isinstance(y, nltk.tree.Tree) and + y.label().nuclearity == 'Nucleus')][0] + lmost_leaf = lmost_nuc.leaves()[0] + lmost_leaves.append(lmost_leaf.num) + # generate dependencies according to the "chain" transform + su_pairs.extend([(gov_idx, dep_idx) for gov_idx, dep_idx in + zip(lmost_leaves, lmost_leaves[1:])]) + return su_pairs + +# (pseudo-)relations used to impose a tree structure: +# span, Same-Unit, TextualOrganization + +# 1. span +# only in ctrees + +# 2. TextualOrganization + +# 3. Same-Unit +# * pseudo-relation +# * the intervening material is attached to one of the constituents, +# usually the first one, but might be the second one if it is more +# appropriate (e.g. for relation that links two events, when the event +# is in the second fragment) +# +# The problem of spurious Same-Unit is not particular to the RST-DT, +# it also shows in the Discourse Graphbank: +# http://www.aclweb.org/anthology/W10-4311 +# Interestingly, this article reveals systematic deviations in the DG +# corpus compared to the definition of the Same relation, by looking +# at cases where the same text is also part of the RST corpus but +# the RST annotation uses another relation than "Same-Unit". +# While this comparison between DG and RST uses the RST treebank as +# a reference, we show that the RST corpus also contains +# inconsistencies. +# +# For another study on Same-Unit in the RST corpus: +# https://www.seas.upenn.edu/~pdtb/papers/BanikLeeLREC08.pdf + +# Marcu 2000: +# * p. 167, on another, related corpus (30 MUC7 coref, 30 Brown-Learned, +# 30 WSJ): +# rhetorical relations + "two constituency relations that were ubiquitous +# in the corpora and that often subsumed complex rhetorical constituents, +# and one textual relation. The constituency relations were `attribution`, +# which was used to label the relation between a reporting and a reported +# clause, and `apposition`. The textual relation was `TextualOrganization`; +# it was used to connect in an RST-like manner the textual spans that +# corresponded to the title, author, and textual body of each document in +# the corpus." 
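For concreteness, a small worked example, not part of the patch, of how the two encodings handled by `binary_to_nary` (added to `educe/rst_dt/deptree.py` earlier in this patch) regroup the binary Same-Unit links of a three-fragment EDU into a single n-ary unit; the fragment indices are made up for illustration:

from educe.rst_dt.deptree import binary_to_nary

# fragments 2, 5 and 8 of one EDU, with intervening EDUs in between:
# "chain" links each fragment to the next one (2 -> 5, 5 -> 8), while
# "tree" links every later fragment to the first one (2 -> 5, 2 -> 8);
# both encodings recover the same 3-fragment unit
assert binary_to_nary('chain', [(2, 5), (5, 8)]) == [(2, 5, 8)]
assert binary_to_nary('tree', [(2, 5), (2, 8)]) == [(2, 5, 8)]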
+ + +def same_units_adjacent(dtree): + """Same-Units where the fragments are adjacent""" + res = [] + for dep_idx, (gov_idx, lbl) in enumerate( + zip(dtree.heads[1:], dtree.labels[1:]), + start=1): + if lbl == 'Same-Unit' and dep_idx - gov_idx == 1: + res.append((gov_idx, dep_idx)) + print('W:', gov_idx, dep_idx) + return res + + +def same_units_both_inside_deps(dtree): + """Same-Units where both fragments govern intervening material""" + res = [] + for dep_idx, (gov_idx, lbl) in enumerate( + zip(dtree.heads[1:], dtree.labels[1:]), + start=1): + if lbl == 'Same-Unit': + gov_deps_i = [x for x in dtree.deps(gov_idx) + if x > gov_idx and x < dep_idx] + dep_deps_i = [x for x in dtree.deps(dep_idx) + if x > gov_idx and x < dep_idx] + if gov_deps_i and dep_deps_i: + res.append((gov_idx, dep_idx)) + return res + + +def same_units_different_sentences(dtree): + """Same-Units where the two fragments belong to different sentences + + (Not yet implemented) + """ + # TODO + return [] + + +def same_units_second_has_inside_attribution(dtree): + """Same-Units with intervening 'attribution' headed by frag2.""" + res = [] + for dep_idx, (gov_idx, lbl) in enumerate( + zip(dtree.heads[1:], dtree.labels[1:]), + start=1): + if lbl == 'Same-Unit': + # frag2 has intervening direct dependents "attribution" + dep_deps_i = [x for x in dtree.deps(dep_idx) + if (x > gov_idx and x < dep_idx and + dtree.labels[x].startswith('attribution'))] + # TODO extend to transitive dependents ? + if dep_deps_i: + res.append((gov_idx, dep_idx)) + return res + + +# note from Feng's PhD thesis: +# their model fails on Topic-Change, Textual-Organization, Topic-Comment, +# Evaluation, because they look "more abstractly defined" => candidate for +# post-proc? impact of WMD on these spans? + +def check_same_units_ctree(ctree): + """Check structural properties of "Same-Unit" fragments. 
+ """ + tree_posits = ctree.treepositions() + for tpos in tree_posits: + node = ctree[tpos] + if not isinstance(node, nltk.tree.Tree): # skip leaf nodes + continue + same_units = [(i, x) for i, x in enumerate(node) + if (isinstance(x, nltk.tree.Tree) and + x.label().rel == 'Same-Unit')] + # weird same-units: n-ary (n>2) same-units + if len(same_units) > 2: + print(ctree.origin.doc, '\tn-ary same-unit\t', + [x[1].label() for x in same_units]) + # weird same-units: nucleus != span[0] + for i, x in same_units: + # compare the leftmost leaf of the leftmost nucleus + # with the recursive leftmost nucleus + # * leftmost leaf of the leftmost nucleus + if ((len(x) == 1 and + not isinstance(x[0], nltk.tree.Tree))): + lmost_leaf = x[0] + else: + lmost_nuc = [y for y in x + if (isinstance(y, nltk.tree.Tree) and + y.label().nuclearity == 'Nucleus')][0] + lmost_leaf = lmost_nuc.leaves()[0] + + # * (recursively found) nucleus of the leftmost nucleus + nuc_cand = x + while True: + if ((len(nuc_cand) == 1 and + not isinstance(nuc_cand[0], nltk.tree.Tree))): + nuc_cand = nuc_cand[0] + break + # else recurse + nuc_cands = [y for y in nuc_cand + if (isinstance(y, nltk.tree.Tree) and + y.label().nuclearity == 'Nucleus')] + if len(nuc_cands) > 1: + print(ctree.origin.doc, '\t>1 nucleus\t', x.label()) + nuc_cand = nuc_cands[0] + if lmost_leaf != nuc_cand: + print(ctree.origin.doc, '\tlmost_leaf != nucleus\t', x.label()) + su_groups = [] # TODO + return su_groups + + +# attribution +if False: + for dt in dtrees: + check_su_right_gov(dt) + raise ValueError('Hop SU right govs') + +# Same-Unit +if False: # run check functions + for x in ctrees: + check_same_units_ctree(x) + + +if True: # get same-unit deps from d and c trees + same_units_nb = 0 + with open('/home/mmorey/melodi/rst_same_unit_suspects_clean1', 'wb') as f: + for ct, dt in zip(ctrees, dtrees): + # typical pathological cases: + # * both gov and dep have deps inside span + both_inside = set(same_units_both_inside_deps(dt)) + # * adjacent fragments + len_one = set(same_units_adjacent(dt)) + # * intervening EDU headed by frag2 + intervening_dep2 = set(same_units_second_has_inside_attribution(dt)) + # * different sentences (not implemented yet) + diff_sents = set(same_units_different_sentences(dt)) + # union of Same-Unit weirdos + su_weirdos = sorted(both_inside | len_one | intervening_dep2 | + diff_sents) + doc_name = ct.origin.doc + if su_weirdos: + print('\n'.join('{}\t{}\t{}'.format( + doc_name, x[0], x[1]) + for x in su_weirdos), + file=f) + # total number of Same-Unit dependencies + all_deps = get_dependencies(dt) + all_same_units = set((gov_idx, dep_idx, lbl) + for gov_idx, dep_idx, lbl + in all_deps + if lbl == 'Same-Unit') + same_units_nb += len(all_same_units) + print('Total number of Same-Unit dependencies:', same_units_nb) + +raise ValueError('Check me') +# end WIP + + +def merge_same_units(dtree): + """Merge fragments of EDUs linked by the pseudo-relation Same-Unit. + + Parameters + ---------- + dtree : RstDepTree + Dependency tree. + + Returns + ------- + dtree_merged : RstDepTree? + Dependency tree with merged EDUs instead of Same-Unit. + """ + raise NotImplementedError('TODO implement merge_same_units') + +# SIGDIAL 2001: +# "In addition, three relations are used to impose structure on the tree: +# textual-organization, span, and same-unit (used to link parts of units +# separated by embedded units or spans)." 
+ +# hence, the following relation labels should or can be separated or +# discarded for evaluation: +# * "span" (in consituency trees; obvious), +# * "root" (in dependency trees; obvious), +# * "same-unit" is a pseudo-relation (pretty obvious), +# * "textual-organization" (? check with NA) + +# Joty also has a relation label "dummy", I suspect they serve the same +# purpose as the "ROOT" label from dependency trees + + +# TODO +# 0. ENH function to merge same-unit +# 1. ENH knn classification on EDUs => try to associate a "semantic class" of event to each EDU, then look at the pair of cluster IDs to decide attachment or labels +# 2. FIX STAC features +# 3. ENH MLP diff --git a/educe/rst_dt/util/cmd/check_tokenization.py b/educe/rst_dt/util/cmd/check_tokenization.py new file mode 100644 index 0000000..94b6c34 --- /dev/null +++ b/educe/rst_dt/util/cmd/check_tokenization.py @@ -0,0 +1,121 @@ +"""Compare tokenization between PTB and CoreNLP for the RST-WSJ corpus. + +""" + +from __future__ import print_function + +import os + +import numpy as np + +from nltk.corpus.reader import BracketParseCorpusReader + +from educe.external.stanford_xml_reader import PreprocessingSource +from educe.rst_dt.corenlp import read_corenlp_result +from educe.rst_dt.corpus import Reader +from educe.rst_dt.deptree import RstDepTree +from educe.rst_dt.document_plus import DocumentPlus +from educe.rst_dt.ptb import PtbParser + + +DATA_DIR = 'data' +PTB_DIR = os.path.join(DATA_DIR, 'PTBIII/parsed/mrg/wsj') +RST_DIR = os.path.join(DATA_DIR, 'rst_discourse_treebank/data') +CORENLP_OUT_DIR = os.path.join(DATA_DIR, 'rst_discourse_treebank', '..', + 'rst-dt-corenlp-2015-01-29') + + +if __name__ == '__main__': + if not os.path.exists(PTB_DIR): + raise ValueError("Unable to find PTB dir {}".format(PTB_DIR)) + if not os.path.exists(RST_DIR): + raise ValueError("Unable to find RST dir {}".format(RST_DIR)) + if not os.path.exists(CORENLP_OUT_DIR): + raise ValueError("Unable to find parsed dir {}".format( + CORENLP_OUT_DIR)) + + corpus = 'RSTtrees-WSJ-main-1.0/TRAINING' + corpus_dir = os.path.join(RST_DIR, corpus) + # syntactic parsers to compare + ptb_reader = BracketParseCorpusReader(PTB_DIR, + r'../wsj_.*\.mrg', + encoding='ascii') + # read the RST corpus + rst_reader = Reader(corpus_dir) + rst_corpus = rst_reader.slurp() + # for each file, compare tokenizations between PTB and CoreNLP + for key, rst_tree in sorted(rst_corpus.items()): + doc_name = key.doc.split('.', 1)[0] + if doc_name.startswith('wsj_'): + print(doc_name) + doc_wsj_num = doc_name.split('_')[1] + section = doc_wsj_num[:2] + + # corenlp stuff + core_fname = os.path.join(CORENLP_OUT_DIR, corpus, + doc_name + '.out.xml') + core_reader = PreprocessingSource() + core_reader.read(core_fname, suffix='') + corenlp_doc = read_corenlp_result(None, core_reader) + core_toks = corenlp_doc.tokens + core_toks_beg = [x.span.char_start for x in core_toks] + core_toks_end = [x.span.char_end for x in core_toks] + + # PTB stuff + # * create DocumentPlus (adapted from educe.rst_dt.corpus) + rst_context = rst_tree.label().context + ptb_docp = DocumentPlus(key, doc_name, rst_context) + # * attach EDUs (yerk) + # FIXME we currently get them via an RstDepTree created from + # the original RSTTree, so as to get the left padding EDU + rst_dtree = RstDepTree.from_rst_tree(rst_tree) + ptb_docp.edus = rst_dtree.edus + # * setup a PtbParser (re-yerk) + ptb_parser = PtbParser(PTB_DIR) + ptb_parser.tokenize(ptb_docp) + # get PTB toks ; skip left padding token + ptb_toks = 
ptb_docp.tkd_tokens[1:] + ptb_toks_beg = ptb_docp.toks_beg[1:] + ptb_toks_end = ptb_docp.toks_end[1:] + + # compare ! + core2ptb_beg = np.searchsorted(ptb_toks_beg, core_toks_beg, + side='left') + core2ptb_end = np.searchsorted(ptb_toks_end, core_toks_end, + side='right') - 1 + # TODO maybe use np.diff? + mism_idc = np.where(core2ptb_beg != core2ptb_end)[0] + # group consecutive indices where beg != end + mismatches = ([(mism_idc[0], mism_idc[0])] if mism_idc.any() + else []) + for elt_cur, elt_nxt in zip(mism_idc[:-1], mism_idc[1:]): + if elt_nxt > elt_cur + 1: + # new mismatch + mismatches.append((elt_nxt, elt_nxt)) + else: # elt_nxt == elt_cur + 1 + # extend current mismatch + mismatches[-1] = (mismatches[-1][0], elt_nxt) + # print mismatches + for core_beg, core_end in mismatches: + m_core_toks = core_toks[core_beg:core_end + 1] # DEBUG + ptb_beg = core2ptb_beg[core_beg] + ptb_end = core2ptb_end[core_end] + n_ptb_toks = ptb_toks[ptb_beg:ptb_end + 1] + if not n_ptb_toks: + print('* Text missing from PTB:', + '({}, {}) '.format( + m_core_toks[0].span.char_start, + m_core_toks[-1].span.char_end), + ' '.join(x.word for x in m_core_toks)) + elif not m_core_toks: + print('* Text missing from RST:', + '({}, {}) '.format( + n_ptb_toks[0].span.char_start, + n_ptb_toks[-1].span.char_end), + ' '.join(x.word for x in n_ptb_toks)) + else: + print('* Mismatch', + '\nCore >>>\t', ' '.join( + unicode(x) for x in m_core_toks), + '\nPTB <<<\t', ' '.join( + unicode(x) for x in n_ptb_toks)) diff --git a/educe/rst_dt/util/cmd/deptree.py b/educe/rst_dt/util/cmd/deptree.py index 8bcf81e..ca4f38f 100644 --- a/educe/rst_dt/util/cmd/deptree.py +++ b/educe/rst_dt/util/cmd/deptree.py @@ -8,15 +8,13 @@ from __future__ import print_function import os -from educe.rst_dt.deptree import RstDepTree +from educe.rst_dt.annotation import SimpleRSTTree from educe.rst_dt.dep2con import deptree_to_simple_rst_tree -import educe.rst_dt +from educe.rst_dt.deptree import RstDepTree -from ..args import\ - add_usual_input_args, add_usual_output_args,\ - read_corpus, get_output_dir, announce_output_dir -from .reltypes import\ - empty_counts, walk_and_count +from ..args import (add_usual_input_args, add_usual_output_args, + read_corpus, get_output_dir, announce_output_dir) +from .reltypes import empty_counts, walk_and_count NAME = 'deptree' @@ -49,7 +47,7 @@ def convert(corpus, multinuclear, odir): for k in corpus: suffix = os.path.splitext(k.doc)[0] - stree = educe.rst_dt.SimpleRSTTree.from_rst_tree(corpus[k]) + stree = SimpleRSTTree.from_rst_tree(corpus[k]) with open(os.path.join(bin_dir, suffix), 'w') as fout: fout.write(str(stree)) @@ -57,7 +55,7 @@ def convert(corpus, multinuclear, odir): with open(os.path.join(dt_dir, suffix), 'w') as fout: fout.write(str(dtree)) - stree2 = deptree_to_simple_rst_tree(dtree, multinuclear) + stree2 = deptree_to_simple_rst_tree(dtree) with open(os.path.join(rst2_dir, suffix), 'w') as fout: fout.write(str(stree2)) diff --git a/educe/rst_dt/util/cmd/text.py b/educe/rst_dt/util/cmd/text.py index 5738338..5bc86f9 100644 --- a/educe/rst_dt/util/cmd/text.py +++ b/educe/rst_dt/util/cmd/text.py @@ -6,11 +6,11 @@ """ from __future__ import print_function + import os -from ..args import\ - add_usual_input_args, add_usual_output_args,\ - read_corpus, get_output_dir, announce_output_dir +from ..args import (add_usual_input_args, add_usual_output_args, + read_corpus, get_output_dir, announce_output_dir) NAME = 'text' diff --git a/educe/rst_dt/util/cmd/tmp.py b/educe/rst_dt/util/cmd/tmp.py index 
7b646a3..e6df7a8 100644 --- a/educe/rst_dt/util/cmd/tmp.py +++ b/educe/rst_dt/util/cmd/tmp.py @@ -8,9 +8,8 @@ from __future__ import print_function from educe.internalutil import treenode -from ..args import\ - add_usual_input_args, add_usual_output_args,\ - read_corpus, get_output_dir, announce_output_dir +from ..args import (add_usual_input_args, add_usual_output_args, + read_corpus, get_output_dir, announce_output_dir) NAME = 'tmp' diff --git a/educe/stac/lexicon/inquirer.py b/educe/stac/lexicon/inquirer.py new file mode 100644 index 0000000..3bb3ccf --- /dev/null +++ b/educe/stac/lexicon/inquirer.py @@ -0,0 +1,39 @@ +"""Load the Inquirer lexicon. + +This code used to live in `educe.stac.learning.features` ; to the best +of my knowledge it is not used anywhere in the current codebase but who +knows? +""" + +from collections import defaultdict +import re + +from educe.learning.educe_csv_format import SparseDictReader + + +def read_inquirer_lexicon(inq_txt_file, classes): + """Read and return the local Inquirer lexicon. + + Parameters + ---------- + inq_txt_file : string + Path to the local text version of the Inquirer. + + classes : list of string + List of classes from the Inquirer that should be included. + + Returns + ------- + words : dict(string, string) + Map from each class to its list of words. + """ + with open(inq_txt_file) as cin: + creader = SparseDictReader(cin, delimiter='\t') + words = defaultdict(list) + for row in creader: + for k in row: + word = row["Entry"].lower() + word = re.sub(r'#.*$', r'', word) + if k in classes: + words[k].append(word) + return words From 5f62f46336af69aaafc7e2f323e2828220e68f82 Mon Sep 17 00:00:00 2001 From: moreymat Date: Tue, 31 Jan 2017 16:31:51 +0100 Subject: [PATCH 14/44] WIP educe.stac document-centric feature extraction --- educe/stac/learning/cmd/extract.py | 39 +- educe/stac/learning/cmd/res_nps.py | 18 +- educe/stac/learning/doc_vectorizer.py | 652 ++++++++++++++++++++++++- educe/stac/learning/features.py | 660 +++----------------------- 4 files changed, 741 insertions(+), 628 deletions(-) diff --git a/educe/stac/learning/cmd/extract.py b/educe/stac/learning/cmd/extract.py index e677bd3..e64a00b 100755 --- a/educe/stac/learning/cmd/extract.py +++ b/educe/stac/learning/cmd/extract.py @@ -10,6 +10,7 @@ from __future__ import print_function from os import path as fp +import itertools import os import sys @@ -17,7 +18,11 @@ from educe.stac.annotation import (DIALOGUE_ACTS, SUBORDINATING_RELATIONS, COORDINATING_RELATIONS) -from educe.stac.learning import features +from educe.stac.learning.doc_vectorizer import ( + DialogueActVectorizer, LabelVectorizer, mk_high_level_dialogues, + extract_pair_features, extract_single_features, read_corpus_inputs) + + import educe.corpus from educe.learning.edu_input_format import (dump_all, labels_comment, @@ -29,12 +34,6 @@ import educe.stac import educe.util -from ..doc_vectorizer import (DialogueActVectorizer, - LabelVectorizer) -from ..features import (strip_cdus, - mk_high_level_dialogues, - extract_pair_features, - extract_single_features) NAME = 'extract' @@ -84,12 +83,9 @@ def config_argparser(parser): def main_single(args): """Extract feature vectors for single EDUs in the corpus.""" - inputs = features.read_corpus_inputs(args) + inputs = read_corpus_inputs(args) stage = 'unannotated' if args.parsing else 'units' dialogues = list(mk_high_level_dialogues(inputs, stage)) - # these paths should go away once we switch to a proper dumper - out_file = fp.join(args.output, - fp.basename(args.corpus) + 
'.dialogue-acts.sparse') instance_generator = lambda x: x.edus[1:] # drop fake root # pylint: disable=invalid-name @@ -102,8 +98,12 @@ def main_single(args): labtor = DialogueActVectorizer(instance_generator, DIALOGUE_ACTS) y_gen = labtor.transform(dialogues) + # create directory structure if not fp.exists(args.output): os.makedirs(args.output) + # these paths should go away once we switch to a proper dumper + out_file = fp.join(args.output, + fp.basename(args.corpus) + '.dialogue-acts.sparse') # list dialogue acts comment = labels_comment(labtor.labelset_) @@ -120,19 +120,16 @@ def main_single(args): def main_pairs(args): """Extract feature vectors for pairs of EDUs in the corpus.""" - inputs = features.read_corpus_inputs(args) + inputs = read_corpus_inputs(args) stage = 'units' if args.parsing else 'discourse' dialogues = list(mk_high_level_dialogues(inputs, stage)) - # these paths should go away once we switch to a proper dumper - out_file = fp.join(args.output, - fp.basename(args.corpus) + '.relations.sparse') instance_generator = lambda x: x.edu_pairs() labels = frozenset(SUBORDINATING_RELATIONS + COORDINATING_RELATIONS) # pylint: disable=invalid-name - # scikit-convention + # X, y follow the naming convention in sklearn feats = extract_pair_features(inputs, stage) vzer = KeyGroupVectorizer() if args.parsing or args.vocabulary: @@ -145,14 +142,14 @@ def main_pairs(args): zero=args.parsing) y_gen = labtor.transform(dialogues) + # create directory structure if not fp.exists(args.output): os.makedirs(args.output) + # these paths should go away once we switch to a proper dumper + out_file = fp.join(args.output, + fp.basename(args.corpus) + '.relations.sparse') - dump_all(X_gen, - y_gen, - out_file, - labtor.labelset_, - dialogues, + dump_all(X_gen, y_gen, out_file, labtor.labelset_, dialogues, instance_generator) # dump vocabulary vocab_file = out_file + '.vocab' diff --git a/educe/stac/learning/cmd/res_nps.py b/educe/stac/learning/cmd/res_nps.py index 47d66ac..848c747 100644 --- a/educe/stac/learning/cmd/res_nps.py +++ b/educe/stac/learning/cmd/res_nps.py @@ -10,29 +10,25 @@ from __future__ import print_function from collections import defaultdict, namedtuple -from itertools import chain import csv import sys from educe.stac import postag, corenlp from educe.stac.annotation import is_edu -from educe.stac.learning import features -from educe.util import\ - add_corpus_filters, fields_without, mk_is_interesting,\ - concat, concat_l +from educe.stac.learning.doc_vectorizer import (mk_env, get_players, + FeatureInput, LexWrapper) +from educe.stac.learning.features import enclosed_trees, is_nplike +from educe.util import (add_corpus_filters, fields_without, mk_is_interesting, + concat, concat_l) import educe.corpus import educe.glozz import educe.learning.keys import educe.stac -from ..features import\ - mk_env, get_players, enclosed_trees, is_nplike,\ - FeatureInput - NAME = 'resource-nps' -LEXICON = features.LexWrapper('domain', 'stac_domain.txt', True) +LEXICON = LexWrapper('domain', 'stac_domain.txt', True) def nplike_trees(current, edu): @@ -130,10 +126,10 @@ def config_argparser(parser): add_corpus_filters(parser, fields=fields_without(["stage"])) parser.set_defaults(func=main) + # --------------------------------------------------------------------- # main # --------------------------------------------------------------------- - def _read_corpus_inputs(args): """ Read and filter the part of the corpus we want features for diff --git a/educe/stac/learning/doc_vectorizer.py 
b/educe/stac/learning/doc_vectorizer.py index 7ccc990..d275e7c 100644 --- a/educe/stac/learning/doc_vectorizer.py +++ b/educe/stac/learning/doc_vectorizer.py @@ -2,7 +2,29 @@ # pylint: disable=too-few-public-methods -from .features import clean_dialogue_act +from __future__ import absolute_import, print_function + +from collections import defaultdict, namedtuple +import itertools +import copy +import os +import sys + +from nltk.corpus import verbnet as vnet + +from educe.learning.keys import KeyGroup +from educe.stac import postag, corenlp +from educe.stac.annotation import addressees, speaker, is_relation_instance +from educe.stac.corpus import twin_key +from educe.stac.fusion import Dialogue, FakeRootEDU, fuse_edus +from educe.stac.learning.features import (clean_dialogue_act, SingleEduKeys, + PairKeys) +# import educe.stac.lexicon.inquirer as inquirer +import educe.stac.lexicon.pdtb_markers as pdtb_markers +from educe.stac.lexicon.wordclass import Lexicon +import educe.stac.graph as stac_gr +import educe.util +import educe.stac UNK = '__UNK__' @@ -27,23 +49,48 @@ def transform(self, raw_documents): """Learn the label encoder and return a vector of labels There is one label per instance extracted from raw_documents. + + Parameters + ---------- + raw_documents : list of `educe.stac.fusion.Dialogue` + List of dialogues. + + Yields + ------ + inst_lbls : list of int + (Integer) label for each instance of the next document. """ # run through documents to generate y for doc in raw_documents: + inst_lbls = [] for edu in self.instance_generator(doc): label = clean_dialogue_act(edu.dialogue_act() or UNK) - yield self.labelset_[label] + inst_lbls.append(self.labelset_[label]) + yield inst_lbls class LabelVectorizer(object): - """Label extractor for the STAC corpus.""" + """Label extractor for the STAC corpus. - def __init__(self, instance_generator, labels, zero=False): - """ - instance_generator to enumerate the instances from a doc + Parameters + ---------- + instance_generator : fun(doc) -> :obj:`list` of (EDU, EDU) + Function to enumerate the instances from a doc. - :type labels: set(string) - """ + labels : :obj:`set` of str + Labelset + + zero : boolean, defaults to False + If True, emit zero for all instances. + + Attributes + ---------- + labelset_ : dict(str, int) + Map from labels to integers ; a few values are reserved: + {UNK: 0, ROOT: 1, UNRELATED: 2}. + """ + + def __init__(self, instance_generator, labels, zero=False): self.instance_generator = instance_generator self.labelset_ = {l: i for i, l in enumerate(labels, start=3)} self.labelset_[UNK] = 0 @@ -55,10 +102,597 @@ def transform(self, raw_documents): """Learn the label encoder and return a vector of labels There is one label per instance extracted from raw_documents. + + Parameters + ---------- + raw_documents : list of ? + Raw documents. + + Yields + ------ + inst_lbls : list of int + (Integer) label for each instance of the next document. 
""" zlabel = UNK if self._zero else UNRELATED # run through documents to generate y for doc in raw_documents: + inst_lbls = [] for pair in self.instance_generator(doc): label = doc.relations.get(pair, zlabel) - yield self.labelset_[label] + inst_lbls.append(self.labelset_[label]) + yield inst_lbls + + +# moved from educe.stac.learning.features +# FIXME refactor into a proper, consistent API: +# this code does a mix of responsibilities from DocumentPlus and other stuff + +# --------------------------------------------------------------------- +# lexicon configuration +# --------------------------------------------------------------------- +class LexWrapper(object): + """ + Configuration options for a given lexicon: where to find it, + what to call it, what sorts of results to return + """ + + def __init__(self, key, filename, classes): + """ + Note: classes=True means we want the (sub)-class of the lexical + item found, and not just a general boolean + """ + self.key = key + self.filename = filename + self.classes = classes + self.lexicon = None + + def read(self, lexdir): + """ + Read and store the lexicon as a mapping from words to their + classes + """ + path = os.path.join(lexdir, self.filename) + self.lexicon = Lexicon.read_file(path) + + +LEXICONS = [ + LexWrapper('domain', 'stac_domain.txt', True), + LexWrapper('robber', 'stac_domain2.txt', False), + LexWrapper('trade', 'trade.txt', True), + LexWrapper('dialog', 'dialog.txt', False), + LexWrapper('opinion', 'opinion.txt', False), + LexWrapper('modifier', 'modifiers.txt', False), + # hand-extracted from trade prediction code, could + # perhaps be merged with one of the other lexicons + # fr.irit.stac.features.CalculsTraitsTache3 + LexWrapper('pronoun', 'pronouns.txt', True), + LexWrapper('ref', 'stac_referential.txt', False) +] + +# PDTB markers +PDTB_MARKERS_BASENAME = 'pdtb_markers.txt' + +# VerbNet +VerbNetEntry = namedtuple("VerbNetEntry", "classname lemmas") + +VERBNET_CLASSES = ['steal-10.5', + 'get-13.5.1', + 'give-13.1-1', + 'want-32.1-1-1', + 'want-32.1', + 'exchange-13.6-1'] + +# Inquirer +INQUIRER_BASENAME = 'inqtabs.txt' + +INQUIRER_CLASSES = ['Positiv', + 'Negativ', + 'Pstv', + 'Ngtv', + 'NegAff', + 'PosAff', + 'If', + 'TrnGain', # maybe gain/loss words related + 'TrnLoss', # ...transactions + 'TrnLw', + 'Food', # related to Catan resources? + 'Tool', # related to Catan resources? + 'Region', # related to Catan game? + 'Route'] # related to Catan game + + +# --------------------------------------------------------------------- +# preprocessing +# --------------------------------------------------------------------- +def player_addresees(edu): + """ + The set of people spoken to during an edu annotation. + This excludes known non-players, like 'All', or '?', or 'Please choose...', + """ + addr1 = addressees(edu) or frozenset() + return frozenset(x for x in addr1 if x not in ['All', '?']) + + +def players_for_doc(corpus, kdoc): + """ + Return the set of speakers/addressees associated with a document. + + In STAC, documents are semi-arbitrarily cut into sub-documents for + technical and possibly ergonomic reasons, ie. meaningless as far as we are + concerned. So to find all speakers, we would have to search all the + subdocuments of a single document. 
:: + + (Corpus, String) -> Set String + """ + speakers = set() + docs = [corpus[k] for k in corpus if k.doc == kdoc] + for doc in docs: + for anno in doc.units: + if educe.stac.is_turn(anno): + turn_speaker = speaker(anno) + if turn_speaker: + speakers.add(turn_speaker) + elif educe.stac.is_edu(anno): + speakers.update(player_addresees(anno)) + return frozenset(speakers) + + +# --------------------------------------------------------------------- +# feature extraction +# --------------------------------------------------------------------- + +# The comments on these named tuples can be docstrings in Python3, +# or we can wrap the class, but eh... + +# feature extraction environment +DocEnv = namedtuple("DocEnv", "inputs current sf_cache") + +# Global resources and settings used to extract feature vectors +FeatureInput = namedtuple('FeatureInput', + ['corpus', 'postags', 'parses', + 'lexicons', 'pdtb_lex', + 'verbnet_entries', + 'inquirer_lex']) + +# A document and relevant contextual information +DocumentPlus = namedtuple('DocumentPlus', + ['key', + 'doc', + 'unitdoc', # equiv doc from units + 'players', + 'parses']) + + +# --------------------------------------------------------------------- +# (single) feature cache +# --------------------------------------------------------------------- +class FeatureCache(dict): + """ + Cache for single edu features. + Retrieving an item from the cache lazily computes/memoises + the single EDU features for it. + """ + def __init__(self, inputs, current): + self.inputs = inputs + self.current = current + super(FeatureCache, self).__init__() + + def __getitem__(self, edu): + if edu.identifier() == ROOT: + return KeyGroup('fake root group', []) + elif edu in self: + return super(FeatureCache, self).__getitem__(edu) + else: + vec = SingleEduKeys(self.inputs) + vec.fill(self.current, edu) + self[edu] = vec + return vec + + def expire(self, edu): + """ + Remove an edu from the cache if it's in there + """ + if edu in self: + del self[edu] + + +# --------------------------------------------------------------------- +# extraction generators +# --------------------------------------------------------------------- +def _get_unit_key(inputs, key): + """ + Given the key for what is presumably a discourse level or + unannotated document, return the key for its unit-level + equivalent. + """ + if key.annotator is None: + twins = [k for k in inputs.corpus if + k.doc == key.doc and + k.subdoc == key.subdoc and + k.stage == 'units'] + return twins[0] if twins else None + else: + twin = copy.copy(key) + twin.stage = 'units' + return twin if twin in inputs.corpus else None + + +def mk_env(inputs, people, key): + """Pre-process and bundle up a representation of the current doc. + + Parameters + ---------- + inputs : FeatureInput + Global information for feature extraction. + + people : dict(str, set(str)) + Set of people involved in the dialogue, for each game (map from + document name to set of players). + + key : FileId + Document identifier. + + Returns + ------- + doc_env : DocEnv + Representation of the designated document, ready for feature + extraction. 
+ """ + doc = inputs.corpus[key] + unit_key = _get_unit_key(inputs, key) + current = DocumentPlus(key=key, doc=doc, + unitdoc=(inputs.corpus[unit_key] if unit_key + else None), + players=people[key.doc], + parses=(inputs.parses[key] if inputs.parses + else None)) + + return DocEnv(inputs=inputs, current=current, + sf_cache=FeatureCache(inputs, current)) + + +def get_players(inputs): + """ + Return a dictionary mapping each document to the set of + players in that document + """ + kdocs = frozenset(k.doc for k in inputs.corpus) + return {x: players_for_doc(inputs.corpus, x) + for x in kdocs} + + +def relation_dict(doc, quiet=False): + """ + Return the relations instances from a document in the + form of an id pair to label dictionary + + If there is more than one relation between a pair of + EDUs we pick one of them arbitrarily and ignore the + other + """ + relations = {} + for rel in doc.relations: + if not is_relation_instance(rel): + # might be the odd Anaphora link lying around + continue + pair = rel.source.identifier(), rel.target.identifier() + if pair not in relations: + relations[pair] = rel.type + elif not quiet: + print(('Ignoring {type1} relation instance ({edu1} -> {edu2}); ' + 'another of type {type2} already exists' + '').format(type1=rel.type, + edu1=pair[0], + edu2=pair[1], + type2=relations[pair]), + file=sys.stderr) + # generate fake root links + for anno in doc.units: + if not educe.stac.is_edu(anno): + continue + is_target = False + for rel in doc.relations: + if rel.target == anno: + is_target = True + break + if not is_target: + key = ROOT, anno.identifier() + relations[key] = ROOT + return relations + + +def _extract_pair(env, edu1, edu2): + """Extract features for a given pair of EDUs. + + Directional, so would have to be called twice. + """ + vec = PairKeys(env.inputs, sf_cache=env.sf_cache) + vec.fill(env.current, edu1, edu2) + return vec + + +def _mk_high_level_dialogues(current): + """Helper to generate dialogues. + + Parameters + ---------- + current : educe.stac.fusion.DocumentPlus + Bundled representation of a document. + + Yields + ------- + dia : `educe.stac.fusion.Dialogue` + Next dialogue + """ + doc = current.doc # this is a GlozzDocument + # first pass: create the EDU objects + edus = sorted([x for x in doc.units if educe.stac.is_edu(x)], + key=lambda y: y.span) + edus_in_dialogues = defaultdict(list) + for edu in edus: + edus_in_dialogues[edu.dialogue].append(edu) + + # finally, generate the high level dialogues + relations = relation_dict(doc) + dialogues = sorted(edus_in_dialogues, key=lambda x: x.span) + for dia in dialogues: + d_edus = edus_in_dialogues[dia] + d_relations = {} + for edu1, edu2 in itertools.product([FakeRootEDU] + d_edus, d_edus): + id_pair = (edu1.identifier(), edu2.identifier()) + rel = relations.get(id_pair) + if rel is not None: + d_relations[(edu1, edu2)] = rel + yield Dialogue(dia, d_edus, d_relations) + + +def mk_envs(inputs, stage): + """Generate an environment for each document in the corpus + within the given stage. + + The environment pools together all the information we + have on a single document. + + Parameters + ---------- + inputs : FeatureInput + Global information used for feature extraction. + + stage : one of {'units', 'discourse', 'unannotated'} + Annotation stage + + Yields + ------- + env : DocEnv + Next environment for feature extraction, one per doc. 
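+
+    Notes
+    -----
+    Hedged usage sketch (not from the original code): the extraction
+    generators below consume these environments one document at a
+    time, e.g.::
+
+        for env in mk_envs(inputs, 'discourse'):
+            for dia in _mk_high_level_dialogues(env.current):
+                pass  # e.g. extract features for dia.edu_pairs()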
+ """ + people = get_players(inputs) + for key in inputs.corpus: + if key.stage != stage: + continue + yield mk_env(inputs, people, key) + + +def mk_high_level_dialogues(inputs, stage): + """Generate all relevant EDU pairs for each designated document. + + Parameters + ---------- + inputs : FeatureInput + Named tuple of global resources and settings used to extract feature + vectors. + + stage : string, one of {'units', 'discourse'} + Stage of annotation + + Yields + ------- + dia : `educe.stac.fusion.Dialogue` + Next dialogue in the Dialogues + """ + for env in mk_envs(inputs, stage): + for dia in _mk_high_level_dialogues(env.current): + yield dia + + +def extract_pair_features(inputs, stage): + """Generator of feature vectors, one per pair of EDUs, in each dialogue. + + Parameters + ---------- + inputs : FeatureInput + Global parameters for feature extraction. + + stage : string, one of {'unannotated', 'units', 'discourse'} + Annotation stage. + + Yields + ------ + vecs : PairEduKeys + List of feature vectors for the EDUs in the next document (here, + STAC dialogue). + """ + for env in mk_envs(inputs, stage): + for dia in _mk_high_level_dialogues(env.current): + vecs = [] + for edu1, edu2 in dia.edu_pairs(): + vec = _extract_pair(env, edu1, edu2) + vecs.append(vec) + yield vecs + + +# --------------------------------------------------------------------- +# extraction generators (single edu) +# --------------------------------------------------------------------- +def extract_single_features(inputs, stage, safety_check=True): + """Generator of feature vectors, one per EDU, in each dialogue. + + Parameters + ---------- + inputs : FeatureInput + Global parameters for feature extraction. + + stage : string, one of {'unannotated', 'units', 'discourse'} + Annotation stage. + + Yields + ------ + vecs : SingleEduKeys + List of feature vectors for the EDUs in the next document (in + fact, STAC subdoc). + """ + for env in mk_envs(inputs, stage): + doc = env.current.doc + if safety_check: + # safety check: ensure all EDUs are processed + # compare EDUs collected initially from the document, with + # all EDUs processed during feature extraction + all_edus = sorted([x for x in doc.units if educe.stac.is_edu(x)], + key=lambda y: y.span) + processed_edus = [] + # skip any documents which are not yet annotated + if env.current.unitdoc is None: + continue + # 2016-01-12 generate one list per Dialogue, rather than one per + # subdoc + for dia in _mk_high_level_dialogues(env.current): + edus = dia.edus[1:] # dia.edus[0]: left padding (fake root) EDU + vecs = [] + for edu in edus: + vec = SingleEduKeys(env.inputs) + vec.fill(env.current, edu) + vecs.append(vec) + processed_edus.append(edu) # safety check + yield vecs + if safety_check: + # safety check: final step + processed_edus = sorted(processed_edus, key=lambda y: y.span) + assert processed_edus == all_edus + + +# --------------------------------------------------------------------- +# input readers +# --------------------------------------------------------------------- +def mk_is_interesting(args, single): + """ + Return a function that filters corpus keys to pick out the ones + we specified on the command line + + We have two cases here: for pair extraction, we just want to + grab the units and if possible the discourse stage. In live mode, + there won't be a discourse stage, but that's fine because we can + just fall back on units. 
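+
+    For the pair case this boils down to (illustrative restatement of
+    the code below, not new behaviour)::
+
+        educe.util.mk_is_interesting(
+            args, preselected={"stage": ["discourse", "units"]})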
+ + For single extraction (dialogue acts), we'll also want to grab the + units stage and fall back to unannotated when in live mode. This + is made a bit trickier by the fact that unannotated does not have + an annotator, so we have to accomodate that. + + Phew. + + It's a bit specific to feature extraction in that here we are + trying + + :type single: bool + """ + if single: + # ignore annotator filter for unannotated documents + args1 = copy.copy(args) + args1.annotator = None + is_interesting1 = educe.util.mk_is_interesting( + args1, preselected={'stage': ['unannotated']}) + # but pay attention to it for units + args2 = args + is_interesting2 = educe.util.mk_is_interesting( + args2, preselected={'stage': ['units']}) + return lambda x: is_interesting1(x) or is_interesting2(x) + else: + preselected = {"stage": ["discourse", "units"]} + return educe.util.mk_is_interesting(args, preselected=preselected) + + +def _fuse_corpus(corpus, postags): + "Merge any dialogue/unit level documents together" + to_delete = [] + for key in corpus: + if key.stage == 'unannotated': + # slightly abusive use of fuse_edus to just get the effect of + # having EDUs that behave like contexts + # + # context: feature extraction for live mode dialogue acts + # extraction, so by definition we don't have a units stage + corpus[key] = fuse_edus(corpus[key], corpus[key], postags[key]) + elif key.stage == 'units': + # similar Context-only abuse of fuse-edus (here, we have a units + # stage but no dialogue to make use of) + # + # context: feature extraction for + # - live mode discourse parsing (by definition we don't have a + # discourse stage yet, but we might have a units stage + # inferred earlier in the parsing pipeline) + # - dialogue act annotation from corpus + corpus[key] = fuse_edus(corpus[key], corpus[key], postags[key]) + elif key.stage == 'discourse': + ukey = twin_key(key, 'units') + corpus[key] = fuse_edus(corpus[key], corpus[ukey], postags[key]) + to_delete.append(ukey) + for key in to_delete: + del corpus[key] + + +def read_corpus_inputs(args): + """Read and filter the part of the corpus we want features for. + + Parameters + ---------- + args : Namespace + Arguments given to the arg parser, in the form of a Namespace + produced by `ArgumentParser.parse_args()`. + + Returns + ------- + feat_input : FeatureInput + Named tuple of global resources and settings used to extract feature + vectors. + """ + reader = educe.stac.Reader(args.corpus) + anno_files = reader.filter(reader.files(), + mk_is_interesting(args, args.single)) + corpus = reader.slurp(anno_files, verbose=True) + + # optional: strip CDUs from the `GlozzDocument`s in the corpus + if not args.ignore_cdus: + # for all documents in the corpus, remove any CDUs and relink the + # document according to the desired mode + # this is performed on a graph model of the document: + # `educe.stac.Graph.strip_cdus()` mutates the graph's doc + for key in corpus: + graph = stac_gr.Graph.from_doc(corpus, key) + graph.strip_cdus(sloppy=True, mode=args.strip_mode) + + # read predicted POS tags, syntactic parse, coreferences etc. 
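+    # (illustrative note, not from the original patch: both readers take
+    # the slurped corpus plus the corpus directory and return per-document
+    # results keyed by FileId, which is how postags[key] and
+    # inputs.parses[key] are looked up elsewhere in this module)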
+ postags = postag.read_tags(corpus, args.corpus) + parses = corenlp.read_results(corpus, args.corpus) + _fuse_corpus(corpus, postags) + + # read our custom lexicons + for lex in LEXICONS: + lex.read(args.resources) + # read lexicon PDTB discourse markers + pdtb_lex_file = os.path.join(args.resources, PDTB_MARKERS_BASENAME) + pdtb_lex = pdtb_markers.read_lexicon(pdtb_lex_file) + # read Inquirer lexicon (disabled) + # inq_txt_file = os.path.join(args.resources, INQUIRER_BASENAME) + # inq_lex = inquirer.read_inquirer_lexicon(inq_txt_file, INQUIRER_CLASSES) + inq_lex = {} + + verbnet_entries = [VerbNetEntry(x, frozenset(vnet.lemmas(x))) + for x in VERBNET_CLASSES] + + return FeatureInput(corpus=corpus, + postags=postags, + parses=parses, + lexicons=LEXICONS, + pdtb_lex=pdtb_lex, + verbnet_entries=verbnet_entries, + inquirer_lex=inq_lex) diff --git a/educe/stac/learning/features.py b/educe/stac/learning/features.py index 621293c..ff40e0a 100644 --- a/educe/stac/learning/features.py +++ b/educe/stac/learning/features.py @@ -7,41 +7,29 @@ """ from __future__ import absolute_import, print_function -from collections import defaultdict, namedtuple, Sequence +from collections import namedtuple, Sequence from functools import wraps -import copy import itertools -import os import re import sys -from nltk.corpus import verbnet as vnet from soundex import Soundex -from educe.annotation import (Span) -from educe.external.parser import\ - SearchableTree,\ - ConstituencyTree -from educe.learning.keys import (MagicKey, Key, KeyGroup, MergedKeyGroup) -from educe.stac import postag, corenlp -from educe.stac.annotation import speaker, addressees, is_relation_instance -from educe.stac.context import (enclosed, - edus_in_span, - turns_in_span) -from educe.stac.corpus import (twin_key) -from educe.learning.educe_csv_format import SparseDictReader, tune_for_csv +from educe.annotation import Span +from educe.external.parser import SearchableTree, ConstituencyTree +from educe.learning.keys import MagicKey, Key, KeyGroup, MergedKeyGroup +from educe.stac.annotation import speaker +from educe.stac.context import enclosed, edus_in_span, turns_in_span +from educe.learning.educe_csv_format import tune_for_csv from educe.learning.util import tuple_feature, underscore import educe.corpus import educe.glozz import educe.stac import educe.stac.lexicon.pdtb_markers as pdtb_markers -import educe.stac.graph as stac_gr import educe.util from ..annotation import turn_id -from ..lexicon.wordclass import Lexicon -from ..fusion import (Dialogue, ROOT, FakeRootEDU, - fuse_edus) +from ..fusion import ROOT, FakeRootEDU class CorpusConsistencyException(Exception): @@ -56,94 +44,9 @@ def __init__(self, msg): super(CorpusConsistencyException, self).__init__(msg) -# --------------------------------------------------------------------- -# lexicon configuration -# --------------------------------------------------------------------- - - -class LexWrapper(object): - """ - Configuration options for a given lexicon: where to find it, - what to call it, what sorts of results to return - """ - - def __init__(self, key, filename, classes): - """ - Note: classes=True means we want the (sub)-class of the lexical - item found, and not just a general boolean - """ - self.key = key - self.filename = filename - self.classes = classes - self.lexicon = None - - def read(self, lexdir): - """ - Read and store the lexicon as a mapping from words to their - classes - """ - path = os.path.join(lexdir, self.filename) - self.lexicon = Lexicon.read_file(path) - - 
-LEXICONS = [LexWrapper('domain', 'stac_domain.txt', True), - LexWrapper('robber', 'stac_domain2.txt', False), - LexWrapper('trade', 'trade.txt', True), - LexWrapper('dialog', 'dialog.txt', False), - LexWrapper('opinion', 'opinion.txt', False), - LexWrapper('modifier', 'modifiers.txt', False), - # hand-extracted from trade prediction code, could - # perhaps be merged with one of the other lexicons - # fr.irit.stac.features.CalculsTraitsTache3 - LexWrapper('pronoun', 'pronouns.txt', True), - LexWrapper('ref', 'stac_referential.txt', False)] - -PDTB_MARKERS_BASENAME = 'pdtb_markers.txt' - -VerbNetEntry = namedtuple("VerbNetEntry", "classname lemmas") - -VERBNET_CLASSES = ['steal-10.5', - 'get-13.5.1', - 'give-13.1-1', - 'want-32.1-1-1', - 'want-32.1', - 'exchange-13.6-1'] - -INQUIRER_BASENAME = 'inqtabs.txt' - -INQUIRER_CLASSES = ['Positiv', - 'Negativ', - 'Pstv', - 'Ngtv', - 'NegAff', - 'PosAff', - 'If', - 'TrnGain', # maybe gain/loss words related - 'TrnLoss', # ...transactions - 'TrnLw', - 'Food', # related to Catan resources? - 'Tool', # related to Catan resources? - 'Region', # related to Catan game? - 'Route'] # related to Catan game - - -# --------------------------------------------------------------------- -# preprocessing -# --------------------------------------------------------------------- -def strip_cdus(corpus, mode): - """ - For all documents in a corpus, remove any CDUs and relink the - document according to the desired mode. This mutates the corpus. - """ - for key in corpus: - graph = stac_gr.Graph.from_doc(corpus, key) - graph.strip_cdus(sloppy=True, mode=mode) - # --------------------------------------------------------------------- # relation queries # --------------------------------------------------------------------- - - def emoticons(tokens): "Given some tokens, return just those which are emoticons" return frozenset(token for token in tokens if token.tag == 'E') @@ -156,15 +59,6 @@ def is_just_emoticon(tokens): return bool(emoticons(tokens)) and len(tokens) == 1 -def player_addresees(edu): - """ - The set of people spoken to during an edu annotation. - This excludes known non-players, like 'All', or '?', or 'Please choose...', - """ - addr1 = addressees(edu) or frozenset() - return frozenset(x for x in addr1 if x not in ['All', '?']) - - def position_of_speaker_first_turn(edu): """ Given an EDU context, determine the position of the first turn by that @@ -179,30 +73,6 @@ def position_of_speaker_first_turn(edu): raise CorpusConsistencyException(oops) -def players_for_doc(corpus, kdoc): - """ - Return the set of speakers/addressees associated with a document. - - In STAC, documents are semi-arbitrarily cut into sub-documents for - technical and possibly ergonomic reasons, ie. meaningless as far as we are - concerned. So to find all speakers, we would have to search all the - subdocuments of a single document. 
:: - - (Corpus, String) -> Set String - """ - speakers = set() - docs = [corpus[k] for k in corpus if k.doc == kdoc] - for doc in docs: - for anno in doc.units: - if educe.stac.is_turn(anno): - turn_speaker = speaker(anno) - if turn_speaker: - speakers.add(turn_speaker) - elif educe.stac.is_edu(anno): - speakers.update(player_addresees(anno)) - return frozenset(speakers) - - def clean_chat_word(token): """ Given a word and its postag (educe PosTag representation) @@ -290,8 +160,8 @@ def prunable(tree): def good(tree): "is within the search span" - return tree.link == "nsubj" and\ - span.encloses(tree.label().text_span()) + return (tree.link == "nsubj" and + span.encloses(tree.label().text_span())) subtrees = map_topdown(good, prunable, trees) return [tree.label().features["lemma"] for tree in subtrees] @@ -320,36 +190,10 @@ def good(tree): return map_topdown(good, prunable, trees) -# --------------------------------------------------------------------- -# feature extraction -# --------------------------------------------------------------------- - -# The comments on these named tuples can be docstrings in Python3, -# or we can wrap the class, but eh... - -# feature extraction environment -DocEnv = namedtuple("DocEnv", "inputs current sf_cache") - -# Global resources and settings used to extract feature vectors -FeatureInput = namedtuple('FeatureInput', - ['corpus', 'postags', 'parses', - 'lexicons', 'pdtb_lex', - 'verbnet_entries', - 'inquirer_lex']) - -# A document and relevant contextual information -DocumentPlus = namedtuple('DocumentPlus', - ['key', - 'doc', - 'unitdoc', # equiv doc from units - 'players', - 'parses']) # --------------------------------------------------------------------- # feature decorators # --------------------------------------------------------------------- - - def type_text(wrapped): """ Given a feature that emits text, clean its output up so to work @@ -383,8 +227,6 @@ def inner(current, edu): # --------------------------------------------------------------------- # # --------------------------------------------------------------------- - - def clean_dialogue_act(act): """ Knock out temporary markers used during corpus annotation @@ -400,8 +242,6 @@ def clean_dialogue_act(act): # --------------------------------------------------------------------- # single EDU non-lexical features # --------------------------------------------------------------------- - - def feat_id(_, edu): "some sort of unique identifier for the EDU" return edu.identifier() @@ -486,8 +326,8 @@ def lemma_subject(current, edu): def is_nplike(anno): "is some sort of NP annotation from a parser" - return isinstance(anno, ConstituencyTree)\ - and anno.label() in ['NP', 'WHNP', 'NNP', 'NNPS'] + return (isinstance(anno, ConstituencyTree) + and anno.label() in ['NP', 'WHNP', 'NNP', 'NNPS']) def has_FOR_np(current, edu): @@ -495,29 +335,31 @@ def has_FOR_np(current, edu): def is_prep_for(anno): "is a node representing for as the prep in a PP" - return isinstance(anno, ConstituencyTree)\ - and anno.label() == 'IN'\ - and len(anno.children) == 1\ - and anno.children[0].features["lemma"] == "for" + return (isinstance(anno, ConstituencyTree) + and anno.label() == 'IN' + and len(anno.children) == 1 + and anno.children[0].features["lemma"] == "for") def is_for_pp_with_np(anno): "is a for PP node (see above) with some NP-like descendant" - return any(is_prep_for(child) for child in anno.children)\ - and anno.topdown(is_nplike, None) + return (any(is_prep_for(child) for child in anno.children) + and 
anno.topdown(is_nplike, None)) trees = enclosed_trees(edu.text_span(), current.parses.trees) return bool(map_topdown(is_for_pp_with_np, None, trees)) -QUESTION_WORDS = ["what", - "which", - "where", - "when", - "who", - "how", - "why", - "whose"] +QUESTION_WORDS = [ + "what", + "which", + "where", + "when", + "who", + "how", + "why", + "whose" +] def is_question(current, edu): @@ -525,8 +367,8 @@ def is_question(current, edu): def is_sqlike(anno): "is some sort of question" - return isinstance(anno, ConstituencyTree)\ - and anno.label() in ['SBARQ', 'SQ'] + return (isinstance(anno, ConstituencyTree) + and anno.label() in ['SBARQ', 'SQ']) doc = current.doc span = edu.text_span() @@ -592,8 +434,8 @@ def speaker_started_the_dialogue(_, edu): def speaker_already_spoken_in_dialogue(_, edu): "if the speaker for this EDU is the same as that of a\ previous turn in the dialogue" - return position_of_speaker_first_turn(edu) <\ - edu.dialogue_turns.index(edu.turn) + return (position_of_speaker_first_turn(edu) + < edu.dialogue_turns.index(edu.turn)) def speakers_first_turn_in_dialogue(_, edu): @@ -601,17 +443,14 @@ def speakers_first_turn_in_dialogue(_, edu): speaker for this EDU first spoke" return 1 + position_of_speaker_first_turn(edu) + # --------------------------------------------------------------------- # pair features # --------------------------------------------------------------------- - - -#pylint: disable=unused-argument def feat_annotator(current, edu1, edu2): "annotator for the subdoc" anno = current.doc.origin.annotator return "none" if anno is None or anno is "" else anno -#pylint: enable=unused-argument @tuple_feature(underscore) # decorator does the pairing boilerplate @@ -629,7 +468,6 @@ def dialogue_act_pairs(current, _, edu): EduGap = namedtuple("EduGap", "sf_cache inner_edus turns_between") -#pylint: disable=unused-argument def num_edus_between(_current, gap, _edu1, _edu2): "number of intervening EDUs (0 if adjacent)" return len(gap.inner_edus) @@ -670,7 +508,6 @@ def has_inner_question(current, gap, _edu1, _edu2): "if there is an intervening EDU that is a question" return any(gap.sf_cache[x]["is_question"] for x in gap.inner_edus) -#pylint: enable=unused-argument def same_speaker(current, _, edu1, edu2): @@ -686,8 +523,6 @@ def same_turn(current, _, edu1, edu2): # --------------------------------------------------------------------- # single EDU lexical features # --------------------------------------------------------------------- - - class LexKeyGroup(KeyGroup): """ The idea here is to provide a feature per lexical class in the @@ -857,11 +692,10 @@ class MergedLexKeyGroup(MergedKeyGroup): Single-EDU features based on lexical lookup. """ def __init__(self, inputs): - groups =\ - [LexKeyGroup(l) for l in inputs.lexicons] +\ - [PdtbLexKeyGroup(inputs.pdtb_lex), - InquirerLexKeyGroup(inputs.inquirer_lex), - VerbNetLexKeyGroup(inputs.verbnet_entries)] + groups = ([LexKeyGroup(l) for l in inputs.lexicons] + + [PdtbLexKeyGroup(inputs.pdtb_lex), + InquirerLexKeyGroup(inputs.inquirer_lex), + VerbNetLexKeyGroup(inputs.verbnet_entries)]) description = "lexical features" super(MergedLexKeyGroup, self).__init__(description, groups) @@ -874,8 +708,6 @@ def fill(self, current, edu, target=None): # --------------------------------------------------------------------- # single EDU non-lexical feature groups # --------------------------------------------------------------------- - - class SingleEduSubgroup(KeyGroup): """ Abstract keygroup for subgroups of the merged SingleEduKeys. 
@@ -914,16 +746,18 @@ class SingleEduSubgroup_Token(SingleEduSubgroup): """ def __init__(self): desc = self.__doc__.strip() - keys =\ - [MagicKey.continuous_fn(num_tokens), - MagicKey.discrete_fn(word_first), - MagicKey.discrete_fn(word_last), - MagicKey.discrete_fn(has_player_name_exact)] + keys = [ + MagicKey.continuous_fn(num_tokens), + MagicKey.discrete_fn(word_first), + MagicKey.discrete_fn(word_last), + MagicKey.discrete_fn(has_player_name_exact) + ] if not sys.version > '3': keys.append(MagicKey.discrete_fn(has_player_name_fuzzy)) - keys2 =\ - [MagicKey.discrete_fn(feat_has_emoticons), - MagicKey.discrete_fn(feat_is_emoticon_only)] + keys2 = [ + MagicKey.discrete_fn(feat_has_emoticons), + MagicKey.discrete_fn(feat_is_emoticon_only) + ] keys.extend(keys2) super(SingleEduSubgroup_Token, self).__init__(desc, keys) @@ -933,10 +767,11 @@ class SingleEduSubgroup_Punct(SingleEduSubgroup): def __init__(self): desc = self.__doc__.strip() - keys =\ - [MagicKey.discrete_fn(has_correction_star), - MagicKey.discrete_fn(ends_with_bang), - MagicKey.discrete_fn(ends_with_qmark)] + keys = [ + MagicKey.discrete_fn(has_correction_star), + MagicKey.discrete_fn(ends_with_bang), + MagicKey.discrete_fn(ends_with_qmark) + ] super(SingleEduSubgroup_Punct, self).__init__(desc, keys) @@ -956,7 +791,8 @@ def __init__(self): MagicKey.discrete_fn(turn_follows_gap), MagicKey.continuous_fn(position_in_dialogue), MagicKey.continuous_fn(position_in_game), - MagicKey.continuous_fn(edu_position_in_turn)] + MagicKey.continuous_fn(edu_position_in_turn) + ] super(SingleEduSubgroup_Chat, self).__init__(desc, keys) @@ -967,10 +803,11 @@ class SingleEduSubgroup_Parser(SingleEduSubgroup): def __init__(self): desc = "parser features" - keys =\ - [MagicKey.discrete_fn(lemma_subject), - MagicKey.discrete_fn(has_FOR_np), - MagicKey.discrete_fn(is_question)] + keys = [ + MagicKey.discrete_fn(lemma_subject), + MagicKey.discrete_fn(has_FOR_np), + MagicKey.discrete_fn(is_question) + ] super(SingleEduSubgroup_Parser, self).__init__(desc, keys) @@ -984,8 +821,7 @@ def __init__(self, inputs): SingleEduSubgroup_Punct(), SingleEduSubgroup_Parser(), MergedLexKeyGroup(inputs)] - super(SingleEduKeys, self).__init__("single EDU features", - groups) + super(SingleEduKeys, self).__init__("single EDU features", groups) def fill(self, current, edu, target=None): """ @@ -1027,9 +863,10 @@ def __init__(self, inputs, sf_cache): self.corpus = inputs.corpus self.sf_cache = sf_cache desc = self.__doc__.strip() - keys =\ - [MagicKey.discrete_fn(is_question_pairs), - MagicKey.discrete_fn(dialogue_act_pairs)] + keys = [ + MagicKey.discrete_fn(is_question_pairs), + MagicKey.discrete_fn(dialogue_act_pairs) + ] super(PairSubgroup_Tuple, self).__init__(desc, keys) def fill(self, current, edu1, edu2, target=None): @@ -1047,13 +884,14 @@ class PairSubgroup_Gap(PairSubgroup): def __init__(self, sf_cache): self.sf_cache = sf_cache desc = "the gap between EDUs" - keys =\ - [MagicKey.continuous_fn(num_edus_between), - MagicKey.continuous_fn(num_speakers_between), - MagicKey.continuous_fn(num_nonling_tstars_between), - MagicKey.discrete_fn(same_speaker), - MagicKey.discrete_fn(same_turn), - MagicKey.discrete_fn(has_inner_question)] + keys = [ + MagicKey.continuous_fn(num_edus_between), + MagicKey.continuous_fn(num_speakers_between), + MagicKey.continuous_fn(num_nonling_tstars_between), + MagicKey.discrete_fn(same_speaker), + MagicKey.discrete_fn(same_turn), + MagicKey.discrete_fn(has_inner_question) + ] super(PairSubgroup_Gap, self).__init__(desc, keys) def 
fill(self, current, edu1, edu2, target=None): @@ -1095,8 +933,7 @@ def __init__(self, inputs, sf_cache=None): self.edu1 = None # will be filled out later self.edu2 = None # from the feature cache - super(PairKeys, self).__init__("pair features", - groups) + super(PairKeys, self).__init__("pair features", groups) def one_hot_values_gen(self, suffix=''): for pair in super(PairKeys, self).one_hot_values_gen(): @@ -1113,354 +950,3 @@ def fill(self, current, edu1, edu2, target=None): vec.edu2 = self.sf_cache[edu2] for group in self.groups: group.fill(current, edu1, edu2, vec) - -# --------------------------------------------------------------------- -# (single) feature cache -# --------------------------------------------------------------------- - - -class FeatureCache(dict): - """ - Cache for single edu features. - Retrieving an item from the cache lazily computes/memoises - the single EDU features for it. - """ - def __init__(self, inputs, current): - self.inputs = inputs - self.current = current - super(FeatureCache, self).__init__() - - def __getitem__(self, edu): - if edu.identifier() == ROOT: - return KeyGroup('fake root group', []) - elif edu in self: - return super(FeatureCache, self).__getitem__(edu) - else: - vec = SingleEduKeys(self.inputs) - vec.fill(self.current, edu) - self[edu] = vec - return vec - - def expire(self, edu): - """ - Remove an edu from the cache if it's in there - """ - if edu in self: - del self[edu] - -# --------------------------------------------------------------------- -# extraction generators -# --------------------------------------------------------------------- - - -def _get_unit_key(inputs, key): - """ - Given the key for what is presumably a discourse level or - unannotated document, return the key for for its unit-level - equivalent. 
- """ - if key.annotator is None: - twins = [k for k in inputs.corpus if - k.doc == key.doc and - k.subdoc == key.subdoc and - k.stage == 'units'] - return twins[0] if twins else None - else: - twin = copy.copy(key) - twin.stage = 'units' - return twin if twin in inputs.corpus else None - - -def mk_env(inputs, people, key): - """ - Pre-process and bundle up a representation of the current document - """ - doc = inputs.corpus[key] - unit_key = _get_unit_key(inputs, key) - current =\ - DocumentPlus(key=key, - doc=doc, - unitdoc=inputs.corpus[unit_key] if unit_key else None, - players=people[key.doc], - parses=inputs.parses[key] if inputs.parses else None) - - return DocEnv(inputs=inputs, - current=current, - sf_cache=FeatureCache(inputs, current)) - - -def get_players(inputs): - """ - Return a dictionary mapping each document to the set of - players in that document - """ - kdocs = frozenset(k.doc for k in inputs.corpus) - return {x: players_for_doc(inputs.corpus, x) - for x in kdocs} - - -def relation_dict(doc, quiet=False): - """ - Return the relations instances from a document in the - form of an id pair to label dictionary - - If there is more than one relation between a pair of - EDUs we pick one of them arbitrarily and ignore the - other - """ - relations = {} - for rel in doc.relations: - if not is_relation_instance(rel): - # might be the odd Anaphora link lying around - continue - pair = rel.source.identifier(), rel.target.identifier() - if pair not in relations: - relations[pair] = rel.type - elif not quiet: - print(('Ignoring {type1} relation instance ({edu1} -> {edu2}); ' - 'another of type {type2} already exists' - '').format(type1=rel.type, - edu1=pair[0], - edu2=pair[1], - type2=relations[pair]), - file=sys.stderr) - # generate fake root links - for anno in doc.units: - if not educe.stac.is_edu(anno): - continue - is_target = False - for rel in doc.relations: - if rel.target == anno: - is_target = True - break - if not is_target: - key = ROOT, anno.identifier() - relations[key] = ROOT - return relations - - -def _extract_pair(env, edu1, edu2): - """ - Extraction for a given pair of EDUs - (directional, so would have to be called twice) - """ - vec = PairKeys(env.inputs, sf_cache=env.sf_cache) - vec.fill(env.current, edu1, edu2) - return vec - - -def _id_pair(pair): - "pair of ids for pair of edus" - edu1, edu2 = pair - return edu1.identifier(), edu2.identifier() - - -def _mk_high_level_dialogues(current): - """ - Returns - ------- - iterator of `educe.stac.fusion.Dialogue` - """ - doc = current.doc - # first pass: create the EDU objects - edus = sorted([x for x in doc.units if educe.stac.is_edu(x)], - key=lambda x: x.span) - edus_in_dialogues = defaultdict(list) - for edu in edus: - edus_in_dialogues[edu.dialogue].append(edu) - - # finally, generat the high level dialogues - relations = relation_dict(doc) - dialogues = sorted(edus_in_dialogues, key=lambda x: x.span) - for dia in dialogues: - d_edus = edus_in_dialogues[dia] - d_relations = {} - for pair in itertools.product([FakeRootEDU] + d_edus, d_edus): - rel = relations.get(_id_pair(pair)) - if rel is not None: - d_relations[pair] = rel - yield Dialogue(dia, d_edus, d_relations) - - -def mk_envs(inputs, stage): - """ - Generate an environment for each document in the corpus - within the given stage. 
- - The environment pools together all the information we - have on a single document - """ - people = get_players(inputs) - for key in inputs.corpus: - if key.stage != stage: - continue - yield mk_env(inputs, people, key) - - -def mk_high_level_dialogues(inputs, stage): - """ - Generate all relevant EDU pairs for a document - (generator) - """ - for env in mk_envs(inputs, stage): - for dia in _mk_high_level_dialogues(env.current): - yield dia - - -def extract_pair_features(inputs, stage): - """ - Extraction for all relevant pairs in a document - (generator) - """ - for env in mk_envs(inputs, stage): - for dia in _mk_high_level_dialogues(env.current): - for edu1, edu2 in dia.edu_pairs(): - yield _extract_pair(env, edu1, edu2) - -# --------------------------------------------------------------------- -# extraction generators (single edu) -# --------------------------------------------------------------------- - - -def extract_single_features(inputs, stage): - """ - Return a dictionary for each EDU - """ - for env in mk_envs(inputs, stage): - doc = env.current.doc - # skip any documents which are not yet annotated - if env.current.unitdoc is None: - continue - edus = [unit for unit in doc.units if educe.stac.is_edu(unit)] - for edu in edus: - vec = SingleEduKeys(env.inputs) - vec.fill(env.current, edu) - yield vec - -# --------------------------------------------------------------------- -# input readers -# --------------------------------------------------------------------- - - -def read_pdtb_lexicon(args): - """ - Read and return the local PDTB discourse marker lexicon. - """ - pdtb_lex_file = os.path.join(args.resources, PDTB_MARKERS_BASENAME) - return pdtb_markers.read_lexicon(pdtb_lex_file) - - -def _read_inquirer_lexicon(args): - """ - Read and return the local PDTB discourse marker lexicon. - """ - inq_txt_file = os.path.join(args.resources, INQUIRER_BASENAME) - with open(inq_txt_file) as cin: - creader = SparseDictReader(cin, delimiter='\t') - words = defaultdict(list) - for row in creader: - for k in row: - word = row["Entry"].lower() - word = re.sub(r'#.*$', r'', word) - if k in INQUIRER_CLASSES: - words[k].append(word) - return words - - -def mk_is_interesting(args, single): - """ - Return a function that filters corpus keys to pick out the ones - we specified on the command line - - We have two cases here: for pair extraction, we just want to - grab the units and if possible the discourse stage. In live mode, - there won't be a discourse stage, but that's fine because we can - just fall back on units. - - For single extraction (dialogue acts), we'll also want to grab the - units stage and fall back to unannotated when in live mode. This - is made a bit trickier by the fact that unannotated does not have - an annotator, so we have to accomodate that. - - Phew. 
- - It's a bit specific to feature extraction in that here we are - trying - - :type single: bool - """ - if single: - # ignore annotator filter for unannotated documents - args1 = copy.copy(args) - args1.annotator = None - is_interesting1 = educe.util.mk_is_interesting( - args1, preselected={'stage': ['unannotated']}) - # but pay attention to it for units - args2 = args - is_interesting2 = educe.util.mk_is_interesting( - args2, preselected={'stage': ['units']}) - return lambda x: is_interesting1(x) or is_interesting2(x) - else: - preselected = {"stage": ["discourse", "units"]} - return educe.util.mk_is_interesting(args, preselected=preselected) - - -def _fuse_corpus(corpus, postags): - "Merge any dialogue/unit level documents together" - to_delete = [] - for key in corpus: - if key.stage == 'unannotated': - # slightly abusive use of fuse_edus to just get the effect of - # having EDUs that behave like contexts - # - # context: feature extraction for live mode dialogue acts - # extraction, so by definition we don't have a units stage - corpus[key] = fuse_edus(corpus[key], corpus[key], postags[key]) - elif key.stage == 'units': - # similar Context-only abuse of fuse-edus (here, we have a units - # stage but no dialogue to make use of) - # - # context: feature extraction for - # - live mode discourse parsing (by definition we don't have a - # discourse stage yet, but we might have a units stage - # inferred earlier in the parsing pipeline) - # - dialogue act annotation from corpus - corpus[key] = fuse_edus(corpus[key], corpus[key], postags[key]) - elif key.stage == 'discourse': - ukey = twin_key(key, 'units') - corpus[key] = fuse_edus(corpus[key], corpus[ukey], postags[key]) - to_delete.append(ukey) - for key in to_delete: - del corpus[key] - - -def read_corpus_inputs(args): - """ - Read and filter the part of the corpus we want features for - """ - reader = educe.stac.Reader(args.corpus) - anno_files = reader.filter(reader.files(), - mk_is_interesting(args, args.single)) - corpus = reader.slurp(anno_files, verbose=True) - - if not args.ignore_cdus: - strip_cdus(corpus, mode=args.strip_mode) - postags = postag.read_tags(corpus, args.corpus) - parses = corenlp.read_results(corpus, args.corpus) - _fuse_corpus(corpus, postags) - - for lex in LEXICONS: - lex.read(args.resources) - pdtb_lex = read_pdtb_lexicon(args) - inq_lex = {} # _read_inquirer_lexicon(args) - - verbnet_entries = [VerbNetEntry(x, frozenset(vnet.lemmas(x))) - for x in VERBNET_CLASSES] - - return FeatureInput(corpus=corpus, - postags=postags, - parses=parses, - lexicons=LEXICONS, - pdtb_lex=pdtb_lex, - verbnet_entries=verbnet_entries, - inquirer_lex=inq_lex) From 199b5341e075b1543394649842bba83b2ab76125 Mon Sep 17 00:00:00 2001 From: moreymat Date: Tue, 31 Jan 2017 17:35:00 +0100 Subject: [PATCH 15/44] WIP educe.rst_dt document-centric feature extraction --- educe/rst_dt/learning/base.py | 2 +- educe/rst_dt/learning/doc_vectorizer.py | 126 +++++++++++++++--------- 2 files changed, 83 insertions(+), 45 deletions(-) diff --git a/educe/rst_dt/learning/base.py b/educe/rst_dt/learning/base.py index a0da658..07f9733 100644 --- a/educe/rst_dt/learning/base.py +++ b/educe/rst_dt/learning/base.py @@ -107,7 +107,7 @@ def lowest_common_parent(treepositions): Parameters ---------- - treepositions : list of tree positions + treepositions : :obj:`list` of tree positions see nltk.tree.Tree.treepositions() Returns diff --git a/educe/rst_dt/learning/doc_vectorizer.py b/educe/rst_dt/learning/doc_vectorizer.py index e8982c1..7dc496c 100644 --- 
a/educe/rst_dt/learning/doc_vectorizer.py +++ b/educe/rst_dt/learning/doc_vectorizer.py @@ -1,5 +1,7 @@ """This submodule implements document vectorizers""" +from __future__ import print_function + import itertools import numbers @@ -9,15 +11,39 @@ class DocumentLabelExtractor(object): - """Label extractor for the RST-DT treebank.""" + """Label extractor for the RST-DT treebank. + + Parameters + ---------- + instance_generator : generator + Generator that enumerates the instances from a doc. + + ordered_pairs : boolean (default: True) + True if the generated instances are ordered pairs of DUs: + (du1, du2) != (du2, du1). + + unknown_label : str + Reserved label for unknown cases. + + labelset : TODO + TODO + + Attributes + ---------- + fixed_labelset_ : boolean + True if the labelset has been fixed, i.e. `self` has been fit. + + labelset_ : dict + A mapping of labels to indices. + + """ def __init__(self, instance_generator, + ordered_pairs=True, unknown_label='__UNK__', labelset=None): - """ - instance_generator to enumerate the instances from a doc - """ self.instance_generator = instance_generator + self.ordered_pairs = ordered_pairs # 2016-09-30 self.unknown_label = unknown_label self.labelset = labelset @@ -26,18 +52,19 @@ def _extract_labels(self, doc): Parameters ---------- - doc: DocumentPlus + doc : DocumentPlus Rich representation of the document. Returns ------- - labels: list of strings or None + labels : list of strings or None List of labels, one for every pair of EDUs (in the order in which they are generated by `self.instance_generator()`). """ edu_pairs = self.instance_generator(doc) # extract one label per EDU pair - labels = doc.relations(edu_pairs) + # WIP 2016-09-30: ordered + labels = doc.relations(edu_pairs, ordered=self.ordered_pairs) return labels def _instance_labels(self, raw_documents): @@ -49,8 +76,7 @@ def _instance_labels(self, raw_documents): analyze = self.build_analyzer() for doc in raw_documents: doc_labels = analyze(doc) - for lab in doc_labels: - yield labelset.get(lab, unk_lab_id) + yield [labelset.get(lab, unk_lab_id) for lab in doc_labels] def _learn_labelset(self, raw_documents, fixed_labelset): """Learn the labelset""" @@ -140,8 +166,8 @@ def fit_transform(self, raw_documents): if not self.fixed_labelset_: self.labelset_ = labelset # re-run through documents to generate y - for lab in self._instance_labels(raw_documents): - yield lab + for doc_labs in self._instance_labels(raw_documents): + yield doc_labs def transform(self, raw_documents): """Transform documents to a label vector""" @@ -150,8 +176,8 @@ def transform(self, raw_documents): if not self.labelset_: raise ValueError('Empty labelset') - for lab in self._instance_labels(raw_documents): - yield lab + for doc_labs in self._instance_labels(raw_documents): + yield doc_labs # helper function to re-emit features from single EDUs in pairs @@ -235,6 +261,10 @@ def _extract_feature_vectors(self, doc): List of feature vectors, one for every pair of EDUs (in the order in which they are generated by `self.instance_generator()`). + + Notes + ----- + This is a bottleneck for speed. 
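+
+        Single-EDU features are memoised per document in a local
+        `sf_cache` keyed by EDU number (see the body below), so each
+        EDU is only vectorized once; pair and in-between features are
+        recomputed for every pair. Roughly::
+
+            if edu1.num not in sf_cache:
+                sf_cache[edu1.num] = dict(sing_extract(
+                    doc, edu_info1, para_info1))
+            feat_dict['EDU1'] = dict(sf_cache[edu1.num])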
""" doc_preprocess = self.doc_preprocess @@ -371,11 +401,10 @@ def _instances(self, raw_documents): analyze = self.build_analyzer() for doc in raw_documents: feat_vecs = analyze(doc) - for feat_vec in feat_vecs: - row = [(vocabulary[fn], fv) - for fn, fv in feat_vec - if fn in vocabulary] - yield row + doc_rows = [[(vocabulary[fn], fv) for fn, fv in feat_vec + if fn in vocabulary] + for feat_vec in feat_vecs] + yield doc_rows def _vocab_df(self, raw_documents, fixed_vocab): """Gather vocabulary (if fixed_vocab=False) and doc frequency @@ -413,15 +442,19 @@ def _vocab_df(self, raw_documents, fixed_vocab): return vocabulary, vocab_df - def _limit_vocabulary(self, vocabulary, vocab_df, - high=None, low=None, limit=None): + def _limit_features(self, vocab_df, vocabulary, high=None, low=None, + limit=None): """Remove too rare or too common features. Prune features that are non zero in more samples than high or less - documents than low, restrict the vocabulary to at most the limit most - frequent. + documents than low, modifying the vocabulary and restricting it to + (TODO at most the limit most frequent). - Returns the set of removed features. + This does not prune samples with zero features. + + This is essentially a reimplementation of the one in + sklearn.feature_extraction.text.CountVectorizer, except vocab_df + is computed differently. """ if high is None and low is None and limit is None: return set() @@ -448,16 +481,14 @@ def _limit_vocabulary(self, vocabulary, vocab_df, new_indices.append(new_idx) prev_idx = new_idx # removed features - removed_feats = set() vocab_items = vocabulary.items() for feat, old_index in vocab_items: if mask[old_index]: vocabulary[feat] = new_indices[old_index] else: del vocabulary[feat] - removed_feats.add(feat) - return vocabulary, removed_feats + return vocabulary def decode(self, doc): """Decode the input into a DocumentPlus. @@ -496,12 +527,6 @@ def _validate_vocabulary(self): def fit(self, raw_documents, y=None): """Learn a vocabulary dictionary of all features from the documents""" - self.fit_transform(raw_documents) - return self - - def fit_transform(self, raw_documents, y=None): - """Learn the vocabulary dictionary and generate (row, (tgt, src)) - """ self._validate_vocabulary() max_df = self.max_df min_df = self.min_df @@ -522,25 +547,38 @@ def fit_transform(self, raw_documents, y=None): raise ValueError( 'max_df corresponds to < documents than min_df') # limit features with df - vocabulary, rm_feats = self._limit_vocabulary(vocabulary, - vocab_df, - high=max_doc_count, - low=min_doc_count, - limit=max_features) + vocabulary = self._limit_features(vocab_df, vocabulary, + high=max_doc_count, + low=min_doc_count, + limit=max_features) self.vocabulary_ = vocabulary - # re-run through documents to generate X - for row in self._instances(raw_documents): - yield row + return self + + def fit_transform(self, raw_documents, y=None): + """Learn the vocabulary dictionary and generate a feature matrix per document. + """ + self.fit(raw_documents, y=y) + return self.transform(raw_documents) def transform(self, raw_documents): - """Transform documents to a feature matrix + """Transform each document to a feature matrix. + + Generate a feature matrix (one row per instance) for each document. + + Parameters + ---------- + raw_documents : TODO + TODO - Note: generator of (row, (tgt, src)) + Yields + ------ + feat_matrix : (row, (tgt, src)) + Feature matrix for the next document. 
""" if not hasattr(self, 'vocabulary_'): self._validate_vocabulary() if not self.vocabulary_: raise ValueError('Empty vocabulary') - for row in self._instances(raw_documents): - yield row + for feat_matrix in self._instances(raw_documents): + yield feat_matrix From aab0212d3b33f8a5f10c3db0b93a773c89fa20e4 Mon Sep 17 00:00:00 2001 From: moreymat Date: Wed, 1 Feb 2017 17:52:52 +0100 Subject: [PATCH 16/44] WIP document-centric feature extraction, contd. --- educe/rst_dt/learning/cmd/extract.py | 265 +++++++++++++++++++----- educe/rst_dt/learning/doc_vectorizer.py | 35 ++-- educe/rst_dt/learning/features_dev.py | 41 ++-- educe/stac/learning/cmd/extract.py | 39 ++-- 4 files changed, 285 insertions(+), 95 deletions(-) diff --git a/educe/rst_dt/learning/cmd/extract.py b/educe/rst_dt/learning/cmd/extract.py index ca95920..383fda5 100644 --- a/educe/rst_dt/learning/cmd/extract.py +++ b/educe/rst_dt/learning/cmd/extract.py @@ -9,15 +9,21 @@ """ from __future__ import print_function -import os +from collections import defaultdict +import csv import itertools +from glob import glob +import os +import sys +import time import educe.corpus import educe.glozz import educe.stac import educe.util -from educe.learning.edu_input_format import (dump_all, +from educe.learning.cdu_input_format import dump_all_cdus +from educe.learning.edu_input_format import (dump_all, dump_labels, load_labels) from educe.learning.vocabulary_format import (dump_vocabulary, load_vocabulary) @@ -94,6 +100,22 @@ def config_argparser(parser): parser.add_argument('--experimental', action='store_true', help='Enable experimental features ' '(currently none)') + # 2016-09-12 nary_enc: chain vs tree transform + parser.add_argument('--nary_enc', default='chain', + choices=['chain', 'tree'], + help='Encoding for n-ary nodes') + # end nary_enc + # WIP 2016-07-15 same-unit + parser.add_argument('--instances', + choices=['edu-pairs', 'same-unit'], + default='edu-pairs', + help="Selection of instances") + # end WIP same-unit + # 2016-09-30 enable to choose between unordered and ordered pairs + parser.add_argument('--unordered_pairs', + action='store_true', + help=("Instances are unordered pairs: " + "(src, tgt) == (tgt, src)")) parser.set_defaults(func=main) @@ -101,6 +123,160 @@ def config_argparser(parser): # main # --------------------------------------------------------------------- +def extract_dump_instances(docs, instance_generator, feature_set, + lecsie_data_dir, vocabulary, + split_feat_space, labels, + live, ordered_pairs, output, corpus): + """Extract and dump instances. + + Parameters + ---------- + docs : list of DocumentPlus + Documents + + instance_generator : (string, function) + Instance generator: the first element is a string descriptor of + the instance generator, the second is the instance generator + itself: a function from DocumentPlus to list of EDU pairs. + + vocabulary : filepath + Path to vocabulary + + split_feat_space : string + Splitter for feature space + + labels : filepath? + Path to labelset? + + live : TODO + TODO + + ordered_pairs : boolean + If True, DU pairs (instances) are ordered pairs, i.e. + (src, tgt) <> (tgt, src). + + output : string + Path to the output directory, e.g. 'TMP/data'. 
+ + corpus : TODO + TODO + """ + # get instance generator and its descriptor + instance_descr, instance_gen = instance_generator + + # setup persistency + if not os.path.exists(output): + os.makedirs(output) + if live: + fn_out = 'extracted-features.{}'.format(instance_descr) + else: + fn_out = '{}.relations.{}'.format( + os.path.basename(corpus), instance_descr) + # vocabulary, labels + fn_ext = '.sparse' # our extension for sparse datasets (svmlight) + vocab_file = os.path.join(output, fn_out + fn_ext + '.vocab') + labels_file = os.path.join(output, fn_out + '.labels') + # WIP 2016-08-29 output folder, will contain n files per doc + # ex: TRAINING/wsj_0601.out.relations.all-pairs.sparse + out_dir = os.path.join(output, os.path.basename(corpus)) + if not os.path.exists(out_dir): + os.makedirs(out_dir) + # end WIP output folder + + # extract vectorized samples + if vocabulary is not None: + vocab = load_vocabulary(vocabulary) + min_df = 1 + else: + vocab = None + min_df = 5 + + vzer = DocumentCountVectorizer(instance_gen, + feature_set, + lecsie_data_dir=lecsie_data_dir, + min_df=min_df, + vocabulary=vocab, + split_feat_space=split_feat_space) + # pylint: disable=invalid-name + # X, y follow the naming convention in sklearn + if vocabulary is not None: + X_gen = vzer.transform(docs) + else: + X_gen = vzer.fit_transform(docs) + # pylint: enable=invalid-name + + # extract class label for each instance + if live: + y_gen = itertools.repeat(0) + else: + if labels is not None: + labelset = load_labels(labels) + else: + labelset = None + labtor = DocumentLabelExtractor(instance_gen, + ordered_pairs=ordered_pairs, + labelset=labelset) + if labels is not None: + labtor.fit(docs) + y_gen = labtor.transform(docs) + else: + # y_gen = labtor.fit_transform(rst_corpus) + # fit then transform enables to get classes_ for the dump + labtor.fit(docs) + y_gen = labtor.transform(docs) + + # dump instances to files + for doc, X, y in itertools.izip(docs, X_gen, y_gen): + # dump EDUs and features in svmlight format + doc_name = doc.key.doc + # TODO refactor + if live: + fn_out = 'extracted-features.{}{}'.format( + instance_descr, fn_ext) + else: + fn_out = '{}.relations.{}{}'.format( + doc_name, instance_descr, fn_ext) + out_file = os.path.join(out_dir, fn_out) + # end TODO refactor + dump_all(X, y, out_file, doc, instance_gen) + + # dump labelset + if labels is not None: + # relative path to get a correct symlink + existing_labels = os.path.relpath( + labels, start=os.path.dirname(labels_file)) + # c/c from attelo.harness.util.force_symlink() + if os.path.islink(labels_file): + os.unlink(labels_file) + elif os.path.exists(labels_file): + oops = ("Can't force symlink from " + labels + + " to " + labels_file + + " because a file of that name already exists") + raise ValueError(oops) + os.symlink(existing_labels, labels_file) + # end c/c + else: + dump_labels(labtor.labelset_, labels_file) + + # dump vocabulary + if vocabulary is not None: + # FIXME relative path to get a correct symlink + existing_vocab = os.path.relpath( + vocabulary, start=os.path.dirname(vocab_file)) + # c/c from attelo.harness.util.force_symlink() + if os.path.islink(vocab_file): + os.unlink(vocab_file) + elif os.path.exists(vocab_file): + oops = ("Can't force symlink from " + vocabulary + + " to " + vocab_file + + " because a file of that name already exists") + raise ValueError(oops) + os.symlink(existing_vocab, vocab_file) + # end c/c + else: + dump_vocabulary(vzer.vocabulary_, vocab_file) + + def main(args): "main for feature extraction 
mode" # retrieve parameters @@ -117,6 +293,7 @@ def main(args): rst_reader = RstDtParser(args.corpus, args, coarse_rels=args.coarse, fix_pseudo_rels=args.fix_pseudo_rels, + nary_enc=args.nary_enc, exclude_file_docs=exclude_file_docs) rst_corpus = rst_reader.corpus # TODO: change educe.corpus.Reader.slurp*() so that they return an object @@ -160,9 +337,17 @@ def main(args): # align EDUs with sentences, tokens and trees from PTB def open_plus(doc): - """Open and fully load a document + """Open and fully load a document. + + Parameters + ---------- + doc : educe.corpus.FileId + Document key. - doc is an educe.corpus.FileId + Returns + ------- + doc : DocumentPlus + Rich representation of the document. """ # create a DocumentPlus doc = rst_reader.decode(doc) @@ -192,54 +377,26 @@ def open_plus(doc): # to iterate over a stable (sorted) list of FileIds docs = [open_plus(doc) for doc in sorted(rst_corpus)] # instance generator - instance_generator = lambda doc: doc.all_edu_pairs() - split_feat_space = 'dir_sent' - # extract vectorized samples - if args.vocabulary is not None: - vocab = load_vocabulary(args.vocabulary) - vzer = DocumentCountVectorizer(instance_generator, - feature_set, - lecsie_data_dir=lecsie_data_dir, - vocabulary=vocab, - split_feat_space=split_feat_space) - X_gen = vzer.transform(docs) - else: - vzer = DocumentCountVectorizer(instance_generator, - feature_set, - lecsie_data_dir=lecsie_data_dir, - min_df=5, - split_feat_space=split_feat_space) - X_gen = vzer.fit_transform(docs) - - # extract class label for each instance - if live: - y_gen = itertools.repeat(0) - elif args.labels is not None: - labelset = load_labels(args.labels) - labtor = DocumentLabelExtractor(instance_generator, - labelset=labelset) - labtor.fit(docs) - y_gen = labtor.transform(docs) - else: - labtor = DocumentLabelExtractor(instance_generator) - # y_gen = labtor.fit_transform(rst_corpus) - # fit then transform enables to get classes_ for the dump - labtor.fit(docs) - y_gen = labtor.transform(docs) + ordered_pairs = not args.unordered_pairs # 2016-09-30 + if args.instances == 'same-unit': + # WIP 2016-07-08 pre-process to find same-units + instance_generator = ('same-unit', + lambda doc: doc.same_unit_candidates()) + split_feat_space = None + elif args.instances == 'edu-pairs': + # all pairs of EDUs + instance_generator = ('edu-pairs', + lambda doc: doc.all_edu_pairs( + ordered=ordered_pairs)) + split_feat_space = 'dir_sent' - # dump instances to files - if not os.path.exists(args.output): - os.makedirs(args.output) - # data file - of_ext = '.sparse' - if live: - out_file = os.path.join(args.output, 'extracted-features' + of_ext) - else: - of_bn = os.path.join(args.output, os.path.basename(args.corpus)) - out_file = '{}.relations{}'.format(of_bn, of_ext) - # dump EDUs and features in svmlight format - dump_all(X_gen, y_gen, out_file, labtor.labelset_, docs, - instance_generator) - # dump vocabulary - vocab_file = out_file + '.vocab' - dump_vocabulary(vzer.vocabulary_, vocab_file) + # do the extraction + extract_dump_instances(docs, instance_generator, feature_set, + lecsie_data_dir, + args.vocabulary, + split_feat_space, + args.labels, + live, + ordered_pairs, + args.output, + args.corpus) diff --git a/educe/rst_dt/learning/doc_vectorizer.py b/educe/rst_dt/learning/doc_vectorizer.py index 7dc496c..073daf3 100644 --- a/educe/rst_dt/learning/doc_vectorizer.py +++ b/educe/rst_dt/learning/doc_vectorizer.py @@ -290,41 +290,44 @@ def _extract_feature_vectors(self, doc): sf_cache = dict() for edu1, edu2 in 
edu_pairs: + edu1_num = edu1.num + edu2_num = edu2.num # WIP interval - if edu1.num < edu2.num: - edul_num = edu1.num - edur_num = edu2.num + if edu1_num < edu2_num: + edul_num = edu1_num + edur_num = edu2_num else: - edul_num = edu2.num - edur_num = edu1.num + edul_num = edu2_num + edur_num = edu1_num bwn_nums = range(edul_num + 1, edur_num) # end WIP interval feat_dict = dict() # retrieve info for each EDU - edu_info1 = edu_infos[edu1.num] - edu_info2 = edu_infos[edu2.num] + edu_info1 = edu_infos[edu1_num] + edu_info2 = edu_infos[edu2_num] # NEW paragraph info try: - para_info1 = para_infos[edu2para[edu1.num]] + para_info1 = para_infos[edu2para[edu1_num]] except TypeError: para_info1 = None try: - para_info2 = para_infos[edu2para[edu2.num]] + para_info2 = para_infos[edu2para[edu2_num]] except TypeError: para_info2 = None # ... and for the EDUs in between (WIP interval) - edu_info_bwn = [edu_infos[x] for x in bwn_nums] + edu_info_bwn = [edu_infos[i] for i in bwn_nums] + # gov EDU - if edu1.num not in sf_cache: - sf_cache[edu1.num] = dict(sing_extract( + if edu1_num not in sf_cache: + sf_cache[edu1_num] = dict(sing_extract( doc, edu_info1, para_info1)) - feat_dict['EDU1'] = dict(sf_cache[edu1.num]) + feat_dict['EDU1'] = dict(sf_cache[edu1_num]) # dep EDU - if edu2.num not in sf_cache: - sf_cache[edu2.num] = dict(sing_extract( + if edu2_num not in sf_cache: + sf_cache[edu2_num] = dict(sing_extract( doc, edu_info2, para_info2)) - feat_dict['EDU2'] = dict(sf_cache[edu2.num]) + feat_dict['EDU2'] = dict(sf_cache[edu2_num]) # pair + in between feat_dict['pair'] = dict(pair_extract( doc, edu_info1, edu_info2, edu_info_bwn)) diff --git a/educe/rst_dt/learning/features_dev.py b/educe/rst_dt/learning/features_dev.py index 5be8a3f..c0bff13 100644 --- a/educe/rst_dt/learning/features_dev.py +++ b/educe/rst_dt/learning/features_dev.py @@ -10,8 +10,8 @@ import numpy as np -from .base import DocumentPlusPreprocessor from educe.ptb.annotation import strip_punctuation, syntactic_node_seq +from educe.rst_dt.learning.base import DocumentPlusPreprocessor from educe.rst_dt.lecsie import (load_lecsie_feats, LINE_FORMAT as LECSIE_LINE_FORMAT) from educe.stac.lexicon.pdtb_markers import (load_pdtb_markers_lexicon, @@ -237,6 +237,17 @@ def extract_single_sentence(doc, edu_info, para_info): except KeyError: pass + +def extract_single_para(doc, edu_info, para_info): + """paragraph features for the EDU""" + # position of DU in paragraph + try: + offset_para = du_info[0]['edu_idx_in_para'] + if offset_para is not None: + yield ('num_edus_from_para_start', offset_para) + except KeyError: + pass + try: rev_offset_para = edu_info['edu_rev_idx_in_para'] if rev_offset_para is not None: @@ -244,9 +255,6 @@ def extract_single_sentence(doc, edu_info, para_info): except KeyError: pass - -def extract_single_para(doc, edu_info, para_info): - """paragraph features for the EDU""" # position of paragraph in doc # * from beginning try: @@ -566,7 +574,8 @@ def extract_pair_sent(doc, edu_info1, edu_info2, edu_info_bwn): # abs_dist does not seem to work well for inter-sent # rel dist - yield ('dist_sent', sent_id1 - sent_id2) + dist_sent = sent_id1 - sent_id2 + yield ('dist_sent', dist_sent) # L/R booleans if sent_id1 < sent_id2: # right attachment (gov < dep) @@ -574,15 +583,16 @@ def extract_pair_sent(doc, edu_info1, edu_info2, edu_info_bwn): elif sent_id1 > sent_id2: # left attachment yield ('sent_left', True) - yield ('sentence_id_diff_div3', (sent_id1 - sent_id2) / 3) + yield ('sentence_id_diff_div3', dist_sent / 3) # offset 
features offset1 = edu_info1['edu_idx_in_sent'] offset2 = edu_info2['edu_idx_in_sent'] if offset1 is not None and offset2 is not None: # offset diff - yield ('offset_diff', offset1 - offset2) - yield ('offset_diff_div3', (offset1 - offset2) / 3) + offset_diff = offset1 - offset2 + yield ('offset_diff', offset_diff) + yield ('offset_diff_div3', offset_diff / 3) # offset pair yield ('offset_div3_pair', (offset1 / 3, offset2 / 3)) @@ -590,8 +600,9 @@ def extract_pair_sent(doc, edu_info1, edu_info2, edu_info_bwn): rev_offset1 = edu_info1['edu_rev_idx_in_sent'] rev_offset2 = edu_info2['edu_rev_idx_in_sent'] if rev_offset1 is not None and rev_offset2 is not None: - yield ('rev_offset_diff', rev_offset1 - rev_offset2) - yield ('rev_offset_diff_div3', (rev_offset1 - rev_offset2) / 3) + rev_offset_diff = rev_offset1 - rev_offset2 + yield ('rev_offset_diff', rev_offset_diff) + yield ('rev_offset_diff_div3', rev_offset_diff / 3) yield ('rev_offset_div3_pair', (rev_offset1 / 3, rev_offset2 / 3)) # revSentenceID @@ -618,6 +629,9 @@ def extract_pair_syntax(doc, edu_info1, edu_info2, edu_info_bwn): edu1 = edu_info1['edu'] edu2 = edu_info2['edu'] + # nums + edu1_num = edu1.num + edu2_num = edu2.num # determine the linear order of {EDU_1, EDU_2} if edu1.num < edu2.num: @@ -631,10 +645,11 @@ def extract_pair_syntax(doc, edu_info1, edu_info2, edu_info_bwn): edu_info_l = edu_info2 edu_info_r = edu_info1 - # intra-sentential case only if tree_idx1 == tree_idx2: - ptree = doc.tkd_trees[tree_idx1] - pheads = doc.lex_heads[tree_idx1] + # intra-sentential + tree_idx = tree_idx1 + ptree = doc.tkd_trees[tree_idx] + pheads = doc.lex_heads[tree_idx] # * DS-LST features # find the head node of EDU1 diff --git a/educe/stac/learning/cmd/extract.py b/educe/stac/learning/cmd/extract.py index e64a00b..01c858a 100755 --- a/educe/stac/learning/cmd/extract.py +++ b/educe/stac/learning/cmd/extract.py @@ -98,23 +98,31 @@ def main_single(args): labtor = DialogueActVectorizer(instance_generator, DIALOGUE_ACTS) y_gen = labtor.transform(dialogues) - # create directory structure - if not fp.exists(args.output): - os.makedirs(args.output) - # these paths should go away once we switch to a proper dumper - out_file = fp.join(args.output, - fp.basename(args.corpus) + '.dialogue-acts.sparse') + # create directory structure: {output}/ + outdir = args.output + if not fp.exists(outdir): + os.makedirs(outdir) + + corpus_name = fp.basename(args.corpus) # list dialogue acts comment = labels_comment(labtor.labelset_) # dump: EDUs, pairings, vectorized pairings with label + # these paths should go away once we switch to a proper dumper + out_file = fp.join(outdir, + corpus_name + '.dialogue-acts.sparse') edu_input_file = out_file + '.edu_input' dump_edu_input_file(dialogues, edu_input_file) dump_svmlight_file(X_gen, y_gen, out_file, comment=comment) # dump vocabulary - vocab_file = out_file + '.vocab' + # WIP 2017-01-11 we might need to insert ".{instance_descr}", + # with e.g. 
instance_descr='edus', before ".sparse", so as to match + # the naming scheme currently used for RST + vocab_file = fp.join(outdir, + '{corpus_name}.dialogue-acts.sparse.vocab'.format( + corpus_name=corpus_name)) dump_vocabulary(vzer.vocabulary_, vocab_file) @@ -143,16 +151,23 @@ def main_pairs(args): y_gen = labtor.transform(dialogues) # create directory structure - if not fp.exists(args.output): - os.makedirs(args.output) + outdir = args.output + if not fp.exists(outdir): + os.makedirs(outdir) + + corpus_name = fp.basename(args.corpus) + # these paths should go away once we switch to a proper dumper - out_file = fp.join(args.output, - fp.basename(args.corpus) + '.relations.sparse') + out_file = fp.join(outdir, + corpus_name + '.relations.sparse') dump_all(X_gen, y_gen, out_file, labtor.labelset_, dialogues, instance_generator) + # dump vocabulary - vocab_file = out_file + '.vocab' + vocab_file = fp.join(outdir, + '{corpus_name}.relations.sparse.vocab'.format( + corpus_name=corpus_name)) dump_vocabulary(vzer.vocabulary_, vocab_file) From 0f7d67872eb685c94f208a658d3ebe509bd16047 Mon Sep 17 00:00:00 2001 From: moreymat Date: Thu, 2 Feb 2017 15:19:19 +0100 Subject: [PATCH 17/44] WIP document-centric feature extraction, part 3 --- educe/learning/edu_input_format.py | 137 +++++++++++++++++++++----- educe/learning/keygroup_vectorizer.py | 70 ++++++++++--- educe/learning/vocabulary_format.py | 1 + educe/stac/learning/cmd/extract.py | 60 ++++++++--- 4 files changed, 221 insertions(+), 47 deletions(-) diff --git a/educe/learning/edu_input_format.py b/educe/learning/edu_input_format.py index 4d6839d..22babfe 100644 --- a/educe/learning/edu_input_format.py +++ b/educe/learning/edu_input_format.py @@ -10,10 +10,16 @@ import six from .svmlight_format import dump_svmlight_file +# WIP load_edu_input_file +# FIXME adapt to STAC +from educe.annotation import Span +from educe.corpus import FileId +from educe.rst_dt.annotation import EDU as RstEDU # pylint: disable=invalid-name # a lot of the names here are chosen deliberately to -# go with scikit convention +# go with sklearn convention + # EDUs def _dump_edu_input_file(docs, f): @@ -67,6 +73,69 @@ def dump_edu_input_file(docs, f): _dump_edu_input_file(docs, f) +# FIXME adapt to STAC +def _load_edu_input_file(f, edu_type): + """Do load.""" + edus = [] + edu2sent = [] + + if edu_type == 'rst-dt': + EDU = RstEDU + # FIXME support STAC + + reader = csv.reader(f, dialect=csv.excel_tab) + for line in reader: + if not line: + continue + edu_gid, edu_txt, grouping, subgroup, edu_start, edu_end = line + # FIXME only works for RST-DT, broken on STAC + # no subdoc in RST-DT, hence no orig_subdoc in global_id for EDU + orig_doc, edu_lid = edu_gid.rsplit('_', 1) + assert grouping == orig_doc # both are the doc_name + origin = FileId(orig_doc, None, None, None) + edu_num = int(edu_lid) + edu_txt = edu_txt.decode('utf-8') + edu_start = int(edu_start) + edu_end = int(edu_end) + edu_span = Span(edu_start, edu_end) + edus.append( + EDU(edu_num, edu_span, edu_txt, origin=origin) + ) + # edu2sent + sent_idx = int(subgroup.split('_sent')[1]) + edu2sent.append(sent_idx) + return {'filename': f.name, + 'edus': edus, + 'edu2sent': edu2sent} + + +def load_edu_input_file(f, edu_type='rst-dt'): + """Load a list of EDUs from a file in the EDU input format. 
+ + Parameters + ---------- + f : str + Path to the .edu_input file + + edu_type : str, one of {'rst-dt'} + Type of EDU to load ; 'rst-dt' is the only type currently + allowed but more should come (unless a unifying type for EDUs + emerge, rendering this parameter useless). + + Returns + ------- + data: dict + Bunch-like object with interesting fields "filename", "edus", + "edu2sent". + """ + if edu_type != 'rst-dt': + raise NotImplementedError( + "edu_type {} not yet implemented".format(edu_type)) + with codecs.open(f, 'rb', 'utf-8') as f: + return _load_edu_input_file(f, edu_type) +# end FIXME adapt to STAC + + # pairings def _dump_pairings_file(docs_epairs, f): """Actually do dump""" @@ -103,42 +172,64 @@ def labels_comment(class_mapping): def _load_labels(f): """Actually read the label set""" - line = f.readline() - seq = line[1:].split()[1:] - labels = {lbl: idx for idx, lbl in enumerate(seq, start=1)} - labels['__UNK__'] = 0 + labels = dict() + for line in f: + i, lbl = line.strip().split() + labels[lbl] = int(i) + assert labels['__UNK__'] == 0 return labels def load_labels(f): - """Read label set (from a features file) into a dictionary mapping labels - to indices and index""" + """Read label set into a dictionary mapping labels to indices""" with codecs.open(f, 'r', 'utf-8') as f: return _load_labels(f) -def dump_all(X_gen, y_gen, f, class_mapping, docs, instance_generator): - """Dump a whole dataset: features (in svmlight) and EDU pairs +def _dump_labels(labelset, f): + """Do dump labels""" + for lbl, i in sorted(labelset.items(), key=lambda x: x[1]): + f.write('{}\t{}\n'.format(i, lbl)) - class_mapping is a mapping from label to int - :type X_gen: iterable of int arrays - :type y_gen: iterable of int - :param f: output features file path - :param class_mapping: dict(string, int) - :param instance_generator: function that returns an iterable - of pairs given a document +def dump_labels(labelset, f): + """Dump labelset as a mapping from label to index. + + Parameters + ---------- + labelset: dict(label, int) + Mapping from label to index. """ - # the labelset will be written in a comment at the beginning of the - # svmlight file - comment = labels_comment(class_mapping) + with codecs.open(f, 'wb', 'utf-8') as f: + _dump_labels(labelset, f) + + +def dump_all(X_gen, y_gen, f, docs, instance_generator): + """Dump a whole dataset: features (in svmlight) and EDU pairs. - # dump: EDUs, pairings, vectorized pairings with label + Parameters + ---------- + X_gen : iterable of iterable of int arrays + Feature vectors. + + y_gen : iterable of iterable of int + Ground truth labels. 
+ + f : str + Output features file path + + docs : list of DocumentPlus + Documents + + instance_generator : function from doc to iterable of pairs + TODO + """ + # dump EDUs edu_input_file = f + '.edu_input' dump_edu_input_file(docs, edu_input_file) - + # dump EDU pairings pairings_file = f + '.pairings' dump_pairings_file((instance_generator(doc) for doc in docs), pairings_file) - - dump_svmlight_file(X_gen, y_gen, f, comment=comment) + # dump vectorized pairings with label + dump_svmlight_file(X_gen, y_gen, f) diff --git a/educe/learning/keygroup_vectorizer.py b/educe/learning/keygroup_vectorizer.py index d609779..ae6735d 100644 --- a/educe/learning/keygroup_vectorizer.py +++ b/educe/learning/keygroup_vectorizer.py @@ -5,16 +5,42 @@ # lots of scikit-conventional names here from collections import defaultdict +import sys + +import numpy as np class KeyGroupVectorizer(object): """Transforms lists of KeyGroups to sparse vectors. + + Attributes + ---------- + vocabulary_ : dict(str, int) + Vocabulary mapping. """ def __init__(self): - self.vocabulary_ = None + self.vocabulary_ = None # FIXME should be set in fit() + + def _count_vocab(self, vectors, fixed_vocab=False): + """Create sparse feature matrices and shared vocabulary. + + Parameters + ---------- + vectors : list of list of KeyGroup + List of feature matrices, one list per doc, one line per + sample. + + fixed_vocab : boolean, defaults to False + If True, use the vocabulary that hopefully has already been + set during `fit()`. + + Returns + ------- + vocabulary : dict(str, int) + Mapping from features to integers. - def _count_vocab(self, vectors, fixed_vocab): - """Create sparse feature matrix and vocabulary + X : list of list of list of tuple(int, float) + List of feature matrices. """ if fixed_vocab: vocabulary = self.vocabulary_ @@ -29,15 +55,22 @@ def _count_vocab(self, vectors, fixed_vocab): # begins row_ptr = [] row_ptr.append(0) + doc_ptr = [] + doc_ptr.append(0) - for vec in vectors: - for feature, featval in vec.one_hot_values_gen(): - try: - feature_acc.append((vocabulary[feature], featval)) - except KeyError: - # ignore unknown features if fixed vocab - continue - row_ptr.append(len(feature_acc)) + print('fit vocab') # DEBUG + for vecs in vectors: + for vec in vecs: + for feature, featval in vec.one_hot_values_gen(): + try: + feature_idx = vocabulary[feature] + feature_acc.append((feature_idx, featval)) + except KeyError: + # ignore unknown features if fixed vocab + continue + row_ptr.append(len(feature_acc)) + doc_ptr.append(row_ptr[-1]) + print('vocab done') # DEBUG if not fixed_vocab: vocabulary = dict(vocabulary) @@ -46,10 +79,23 @@ def _count_vocab(self, vectors, fixed_vocab): # build a feature count matrix out of feature_acc and row_ptr X = [] + doc_nxt = 0 for i in xrange(len(row_ptr) - 1): current_row, next_row = row_ptr[i], row_ptr[i + 1] + if current_row == doc_ptr[doc_nxt]: + # start a new doc matrix + X.append([]) + doc_nxt += 1 + print('doc ', str(i)) # DEBUG x = feature_acc[current_row:next_row] - X.append(x) + X[-1].append(x) + # DEBUG + n_edus = [len(y) for y in X] + print(len(vocabulary), sys.getsizeof(vocabulary)) + print(len(X), sum(len(y) for y in X), sys.getsizeof(X)) + print(sum(nb_edus * (nb_edus - 1) for nb_edus in n_edus)) + raise ValueError('woopti') + # end DEBUG return vocabulary, X def fit_transform(self, vectors): diff --git a/educe/learning/vocabulary_format.py b/educe/learning/vocabulary_format.py index 74a7aca..685d4c7 100644 --- a/educe/learning/vocabulary_format.py +++ 
b/educe/learning/vocabulary_format.py @@ -3,6 +3,7 @@ import codecs + def _dump_vocabulary(vocabulary, f): """Actually do dump""" line_pattern = u'{fn}\t{fx}\n' diff --git a/educe/stac/learning/cmd/extract.py b/educe/stac/learning/cmd/extract.py index 01c858a..2baa2f2 100755 --- a/educe/stac/learning/cmd/extract.py +++ b/educe/stac/learning/cmd/extract.py @@ -109,12 +109,33 @@ def main_single(args): comment = labels_comment(labtor.labelset_) # dump: EDUs, pairings, vectorized pairings with label - # these paths should go away once we switch to a proper dumper - out_file = fp.join(outdir, - corpus_name + '.dialogue-acts.sparse') - edu_input_file = out_file + '.edu_input' - dump_edu_input_file(dialogues, edu_input_file) - dump_svmlight_file(X_gen, y_gen, out_file, comment=comment) + # WIP switch to a document (here dialogue) centric generation of data + # 1. create a folder for the corpus: {output}/{corpus}/ + outdir_corpus = fp.join(outdir, corpus_name) + if not fp.exists(outdir_corpus): + os.makedirs(outdir_corpus) + # 2. dump edu_input and features files + if args.file_split == 'dialogue': + # one file per dialogue + for dia, X, y in itertools.izip(dialogues, X_gen, y_gen): + dia_id = dia.grouping + print('dump dialogue', dia_id) + # these paths should go away once we switch to a proper dumper + feat_file = fp.join(outdir_corpus, + '{dia_id}.dialogue-acts.sparse'.format( + dia_id=dia_id)) + edu_input_file = '{feat_file}.edu_input'.format(feat_file=feat_file) + dump_edu_input_file(dia, edu_input_file) + dump_svmlight_file(X, y, feat_file, comment=comment) + elif args.file_split == 'corpus': + # one file per corpus (in fact, corpus split) + # these paths should go away once we switch to a proper dumper + out_file = fp.join(outdir, + corpus_name + '.dialogue-acts.sparse') + edu_input_file = out_file + '.edu_input' + dump_edu_input_file(dialogues, edu_input_file) + dump_svmlight_file(X_gen, y_gen, out_file, comment=comment) + # end WIP # dump vocabulary # WIP 2017-01-11 we might need to insert ".{instance_descr}", @@ -157,12 +178,27 @@ def main_pairs(args): corpus_name = fp.basename(args.corpus) - # these paths should go away once we switch to a proper dumper - out_file = fp.join(outdir, - corpus_name + '.relations.sparse') - - dump_all(X_gen, y_gen, out_file, labtor.labelset_, dialogues, - instance_generator) + # WIP switch to a document (here dialogue) centric generation of data + outdir_corpus = fp.join(outdir, corpus_name) + if not fp.exists(outdir_corpus): + os.makedirs(outdir_corpus) + if args.file_split == 'dialogue': + for dia, X, y in itertools.izip(dialogues, X_gen, y_gen): + dia_id = dia.grouping + # these paths should go away once we switch to a proper dumper + out_file = fp.join(outdir_corpus, + '{dia_id}.relations.sparse'.format( + dia_id=dia_id)) + dump_all(X, y, out_file, dia, instance_generator) + elif args.file_split == 'corpus': + # one file per corpus (in fact corpus split) + # these paths should go away once we switch to a proper dumper + out_file = fp.join(outdir, + corpus_name + '.relations.sparse') + + dump_all(X_gen, y_gen, out_file, labtor.labelset_, dialogues, + instance_generator) + # end WIP # dump vocabulary vocab_file = fp.join(outdir, From 6b37d86d38998fa7cc85cef995f27d88c4918f8b Mon Sep 17 00:00:00 2001 From: moreymat Date: Thu, 2 Feb 2017 15:20:39 +0100 Subject: [PATCH 18/44] WIP rst_dt: fragmented EDUs --- educe/rst_dt/frag_edus.py | 45 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) create mode 100644 educe/rst_dt/frag_edus.py diff 
--git a/educe/rst_dt/frag_edus.py b/educe/rst_dt/frag_edus.py new file mode 100644 index 0000000..b7c1144 --- /dev/null +++ b/educe/rst_dt/frag_edus.py @@ -0,0 +1,45 @@ +"""This module provides an API for fragmented EDUs.""" + +from __future__ import absolute_import, print_function + + +def edu_num(edu_id): + """Get the index (number) of an EDU from its identifier. + + A variant of this probably exists elsewhere in the code base, but I + can't seem to find it as of 2017-02-01. + + Parameters + ---------- + edu_id : str + Identifier of the EDU. + + Returns + ------- + edu_num : int + Position index of this EDU in the document. + """ + if edu_id == 'ROOT': + return 0 + return int(edu_id.rsplit('_', 1)[1]) + + +def edu_members(du): + """Get a tuple with the num of the EDUs members of a DU. + + Parameters + ---------- + du : EDU or :obj:`tuple` of str + Discourse Unit, either an EDU or a non-elementary DU described + by the tuple of the identifiers of its EDU members. + + Returns + ------- + mem_nums : :obj:`tuple` of int + Numbers of the EDU members of this DU. + """ + if isinstance(du, tuple): # frag EDU, CDU + # get the EDUs from their identifiers + return tuple(edu_num(x) for x in du[1]) + else: + return tuple([edu_num(du.identifier())]) From 207bbd5e68cd0e119e905e48ac9f3e3a785d28ce Mon Sep 17 00:00:00 2001 From: moreymat Date: Thu, 2 Feb 2017 15:21:23 +0100 Subject: [PATCH 19/44] WIP disdep format --- educe/learning/disdep_format.py | 112 ++++++++++++++++++++++++++++++++ 1 file changed, 112 insertions(+) create mode 100644 educe/learning/disdep_format.py diff --git a/educe/learning/disdep_format.py b/educe/learning/disdep_format.py new file mode 100644 index 0000000..015523d --- /dev/null +++ b/educe/learning/disdep_format.py @@ -0,0 +1,112 @@ +"""Dependency format for RST discourse trees. + +One line per EDU. +""" + +from __future__ import absolute_import, print_function +import codecs +import csv +import os + +from educe.rst_dt.corpus import (RELMAP_112_18_FILE, RstRelationConverter, + Reader) +from educe.rst_dt.deptree import RstDepTree +from educe.rst_dt.rst_wsj_corpus import TRAIN_FOLDER, TEST_FOLDER + +RELCONV = RstRelationConverter(RELMAP_112_18_FILE).convert_label + + +def _dump_disdep_file(rst_deptree, f): + """Actually do dump""" + writer = csv.writer(f, dialect=csv.excel_tab) + + # 0 is the fake root, there is no point in writing its info + edus = rst_deptree.edus[1:] + heads = rst_deptree.heads[1:] + labels = rst_deptree.labels[1:] + nucs = rst_deptree.nucs[1:] + ranks = rst_deptree.ranks[1:] + + for i, (edu, head, label, nuc, rank) in enumerate( + zip(edus, heads, labels, nucs, ranks), start=1): + # text of EDU ; some EDUs have newlines in their text, so convert + # those to simple spaces + txt = edu.text().replace('\n', ' ') + clabel = RELCONV(label) + writer.writerow([i, txt, head, label, clabel, nuc, rank]) + + +def dump_disdep_file(rst_deptree, f): + """Dump dependency RST tree to a disdep file. + + Parameters + ---------- + doc: DocumentPlus + (Rich representation of) the document. + f: str + Path of the output file. + """ + with codecs.open(f, 'wb', 'utf-8') as f: + _dump_disdep_file(rst_deptree, f) + + +def dump_disdep_files(rst_deptrees, out_dir): + """Dump dependency RST trees to a folder. + + This creates one file per RST tree plus a metadata file (encoding of + n-ary relations, coarse-to-fine mapping for relation labels). 
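For reference, each per-tree row written by `_dump_disdep_file` above is tab-separated as (index, EDU text, head index, fine-grained label, coarse label, nuclearity, rank); a minimal reading sketch (the file name is made up):

```python
from __future__ import print_function
import csv

# columns follow _dump_disdep_file: index, EDU text, head, fine label,
# coarse label, nuclearity, rank
with open('wsj_0601.out.dis_dep') as f:  # made-up file name
    for i, txt, head, label, clabel, nuc, rank in csv.reader(
            f, dialect=csv.excel_tab):
        print(int(i), int(head), clabel, nuc, int(rank))
```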
+ """ + if not os.path.exists(out_dir): + os.makedirs(out_dir) + + # metadata + nary_encs = [x.nary_enc for x in rst_deptrees] + assert len(set(nary_encs)) == 1 + nary_enc = nary_encs[0] + f_meta = os.path.join(out_dir, 'metadata') + with codecs.open(f_meta, mode='w', encoding='utf-8') as f_meta: + print('nary_enc: {}'.format(nary_enc), file=f_meta) + print('relmap: {}'.format(RELMAP_112_18_FILE), file=f_meta) + + for rst_deptree in rst_deptrees: + doc_name = rst_deptree.origin.doc + f_doc = os.path.join(out_dir, '{}.dis_dep'.format(doc_name)) + dump_disdep_file(rst_deptree, f_doc) + + +def main(): + """A main that should probably become a proper executable script""" + # TODO expose these parameters with an argparser + corpus_dir = os.path.join( + '/home/mmorey/corpora/rst_discourse_treebank/', + 'data' + ) + dir_train = os.path.join(corpus_dir, TRAIN_FOLDER) + dir_test = os.path.join(corpus_dir, TEST_FOLDER) + + out_dir = os.path.join( + '/home/mmorey/melodi/irit-rst-dt/TMP_disdep_chain_true' + ) + nary_enc = 'chain' # 'tree' + # end TODO + + # convert and dump RST trees from train + reader_train = Reader(dir_train) + trees_train = reader_train.slurp() + dtrees_train = {doc_name: RstDepTree.from_rst_tree(rst_tree, + nary_enc=nary_enc) + for doc_name, rst_tree in trees_train.items()} + dump_disdep_files(dtrees_train.values(), + os.path.join(out_dir, os.path.basename(dir_train))) + # convert and dump RST trees from test + reader_test = Reader(dir_test) + trees_test = reader_test.slurp() + dtrees_test = {doc_name: RstDepTree.from_rst_tree(rst_tree, + nary_enc=nary_enc) + for doc_name, rst_tree in trees_test.items()} + dump_disdep_files(dtrees_test.values(), + os.path.join(out_dir, os.path.basename(dir_test))) + + +if __name__ == '__main__': + main() From 4ee8c03fb2ccda63a86220ff7d6209986826d70f Mon Sep 17 00:00:00 2001 From: moreymat Date: Thu, 2 Feb 2017 17:23:59 +0100 Subject: [PATCH 20/44] WIP feature extraction runs on RST-DT, file_split=corpus --- educe/learning/svmlight_format.py | 28 +++++------ educe/rst_dt/learning/cmd/extract.py | 67 ++++++++++++++++++++------- educe/rst_dt/learning/features_dev.py | 2 +- 3 files changed, 67 insertions(+), 30 deletions(-) diff --git a/educe/learning/svmlight_format.py b/educe/learning/svmlight_format.py index 67cad65..09f8293 100644 --- a/educe/learning/svmlight_format.py +++ b/educe/learning/svmlight_format.py @@ -19,19 +19,21 @@ def _dump_svmlight(X_gen, y_gen, f, comment): line_pattern = '{yi}' line_pattern += ' {s}\n' - for x, yi in itertools.izip(X_gen, y_gen): - # sort features by their index - x = sorted(x) - # zero values need not be written in the svmlight format - x = [(feat_id, feat_val) for feat_id, feat_val in x - if feat_val != 0] - # feature ids in libsvm are one-based, so feat_id + 1 - # TODO use unicode all along, then encode to ascii at the last - # possible moment (aka here), e.g. - # s = u' '.join(...) ; f.write(... .encode('ascii')) - s = ' '.join(value_pattern.format(fid=str(feat_id + 1), fv=feat_val) - for feat_id, feat_val in x) - f.write(line_pattern.format(yi=yi, s=s)) + for X, y in itertools.izip(X_gen, y_gen): + for x, yi in itertools.izip(X, y): + # sort features by their index + x = sorted(x) + # zero values need not be written in the svmlight format + x = [(feat_id, feat_val) for feat_id, feat_val in x + if feat_val != 0] + # feature ids in libsvm are one-based, so feat_id + 1 + # TODO use unicode all along, then encode to ascii at the last + # possible moment (aka here), e.g. + # s = u' '.join(...) ; f.write(... 
.encode('ascii')) + s = ' '.join(value_pattern.format(fid=str(feat_id + 1), + fv=feat_val) + for feat_id, feat_val in x) + f.write(line_pattern.format(yi=yi, s=s)) def dump_svmlight_file(X_gen, y_gen, f, zero_based=True, comment=None, diff --git a/educe/rst_dt/learning/cmd/extract.py b/educe/rst_dt/learning/cmd/extract.py index 383fda5..b56df66 100644 --- a/educe/rst_dt/learning/cmd/extract.py +++ b/educe/rst_dt/learning/cmd/extract.py @@ -22,7 +22,6 @@ import educe.stac import educe.util -from educe.learning.cdu_input_format import dump_all_cdus from educe.learning.edu_input_format import (dump_all, dump_labels, load_labels) from educe.learning.vocabulary_format import (dump_vocabulary, @@ -116,6 +115,16 @@ def config_argparser(parser): action='store_true', help=("Instances are unordered pairs: " "(src, tgt) == (tgt, src)")) + # WIP 2017-02-02 toggle between corpus- and doc-centric feature + # extraction + parser.add_argument('--file_split', + choices=['doc', 'corpus'], + default='corpus', + help=("Level of granularity for each set of" + "files: 'doc' produces one set of files per" + "document ; 'corpus' one set of files per" + "corpus split (e.g. 'train', 'test')")) + # end WIP toggle between corpus- and doc-centric feature extraction parser.set_defaults(func=main) @@ -126,7 +135,8 @@ def config_argparser(parser): def extract_dump_instances(docs, instance_generator, feature_set, lecsie_data_dir, vocabulary, split_feat_space, labels, - live, ordered_pairs, output, corpus): + live, ordered_pairs, output, corpus, + file_split='corpus'): """Extract and dump instances. Parameters @@ -167,11 +177,17 @@ def extract_dump_instances(docs, instance_generator, feature_set, # setup persistency if not os.path.exists(output): os.makedirs(output) + + corpus_name = os.path.basename(corpus) + if live: - fn_out = 'extracted-features.{}'.format(instance_descr) + fn_out = 'extracted-features.{instance_descr}'.format( + instance_descr=instance_descr) else: - fn_out = '{}.relations.{}'.format( - os.path.basename(corpus), instance_descr) + fn_out = '{corpus_name}.relations.{instance_descr}'.format( + corpus_name=corpus_name, + instance_descr=instance_descr) + # vocabulary, labels fn_ext = '.sparse' # our extension for sparse datasets (svmlight) vocab_file = os.path.join(output, fn_out + fn_ext + '.vocab') @@ -226,19 +242,37 @@ def extract_dump_instances(docs, instance_generator, feature_set, y_gen = labtor.transform(docs) # dump instances to files - for doc, X, y in itertools.izip(docs, X_gen, y_gen): - # dump EDUs and features in svmlight format - doc_name = doc.key.doc - # TODO refactor + if file_split == 'doc': + # one set of files per document + for doc, X, y in itertools.izip(docs, X_gen, y_gen): + # dump EDUs and features in svmlight format + doc_name = doc.key.doc + # TODO refactor + if live: + fn_out = 'extracted-features.{}{}'.format( + instance_descr, fn_ext) + else: + fn_out = '{}.relations.{}{}'.format( + doc_name, instance_descr, fn_ext) + out_file = os.path.join(out_dir, fn_out) + # end TODO refactor + dump_all([X], [y], out_file, [doc], instance_gen) + elif file_split == 'corpus': + # one set of files per corpus (in fact, corpus split) if live: - fn_out = 'extracted-features.{}{}'.format( - instance_descr, fn_ext) + fn_out = 'extracted-features.{instance_descr}{fn_ext}'.format( + instance_descr=instance_descr, fn_ext=fn_ext) else: - fn_out = '{}.relations.{}{}'.format( - doc_name, instance_descr, fn_ext) + fn_out = ('{corpus_name}.relations.{instance_descr}{fn_ext}' + .format( + 
corpus_name=corpus_name, + instance_descr=instance_descr, + fn_ext=fn_ext)) out_file = os.path.join(out_dir, fn_out) - # end TODO refactor - dump_all(X, y, out_file, doc, instance_gen) + dump_all(X_gen, y_gen, out_file, docs, instance_gen) + else: + raise ValueError('Unknown value for args.file_split : {}'.format( + args.file_split)) # dump labelset if labels is not None: @@ -399,4 +433,5 @@ def open_plus(doc): live, ordered_pairs, args.output, - args.corpus) + args.corpus, + file_split=args.file_split) diff --git a/educe/rst_dt/learning/features_dev.py b/educe/rst_dt/learning/features_dev.py index c0bff13..11cd65a 100644 --- a/educe/rst_dt/learning/features_dev.py +++ b/educe/rst_dt/learning/features_dev.py @@ -242,7 +242,7 @@ def extract_single_para(doc, edu_info, para_info): """paragraph features for the EDU""" # position of DU in paragraph try: - offset_para = du_info[0]['edu_idx_in_para'] + offset_para = edu_info['edu_idx_in_para'] if offset_para is not None: yield ('num_edus_from_para_start', offset_para) except KeyError: From 7cfc500ad74732c62bad6441ed6196b393034826 Mon Sep 17 00:00:00 2001 From: moreymat Date: Thu, 9 Feb 2017 11:53:24 +0100 Subject: [PATCH 21/44] FIX from/to SimpleRSTTree: nuc moved up too --- educe/rst_dt/dep2con.py | 123 ++++++++++++++++------------------------ educe/rst_dt/deptree.py | 2 +- educe/rst_dt/parse.py | 13 +++-- 3 files changed, 57 insertions(+), 81 deletions(-) diff --git a/educe/rst_dt/dep2con.py b/educe/rst_dt/dep2con.py index 65ebeea..2c6b721 100644 --- a/educe/rst_dt/dep2con.py +++ b/educe/rst_dt/dep2con.py @@ -593,59 +593,6 @@ def deptree_to_simple_rst_tree(dtree, allow_forest=False): (and so on, until all we have left is a single RST tree). """ - - def mk_leaf(edu): - """ - Trivial partial tree for use when processing dependency - tree leaves - """ - return TreeParts(edu=edu, - edu_span=(edu.num, edu.num), - span=edu.text_span(), - rel="leaf", - kids=[]) - - def parts_to_tree(nuclearity, parts): - """ - Combine root nuclearity information with a partial tree - to form a full RST `SimpleTree` - """ - node = Node(nuclearity, - parts.edu_span, - parts.span, - parts.rel) - kids = parts.kids or [parts.edu] - return SimpleRSTTree(node, kids) - - def connect_trees(src, tgt, rel, nuc): - """ - Return a partial tree, assigning order and nuclearity to - child trees - """ - tgt_nuc = nuc - - if src.span.overlaps(tgt.span): - raise RstDtException("Span %s overlaps with %s " % - (src.span, tgt.span)) - elif src.span <= tgt.span: - left = parts_to_tree(NUC_N, src) - right = parts_to_tree(tgt_nuc, tgt) - else: - left = parts_to_tree(tgt_nuc, tgt) - right = parts_to_tree(NUC_N, src) - - l_edu_span = treenode(left).edu_span - r_edu_span = treenode(right).edu_span - - edu_span = (min(l_edu_span[0], r_edu_span[0]), - max(l_edu_span[1], r_edu_span[1])) - res = TreeParts(edu=src.edu, - edu_span=edu_span, - span=src.span.merge(tgt.span), - rel=rel, - kids=[left, right]) - return res - def walk(ancestor, subtree): """ The basic descent/ascent driver of our conversion algorithm. @@ -674,28 +621,67 @@ def walk(ancestor, subtree): Parameters ---------- - ancestor: TreeParts - TreeParts of the ancestor + ancestor : SimpleRSTTree + SimpleRSTTree of the ancestor - subtree: int + subtree : int Index of the head of the subtree Returns ------- - res: TreeParts + res : SimpleRSTTree + SimpleRSTTree covering ancestor and subtree. 
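A minimal end-to-end sketch of the conversion this rewritten walk() supports, going from a gold c-tree to a dependency tree and back (corpus path made up; assumes a single real root per document):

```python
from educe.rst_dt.corpus import Reader
from educe.rst_dt.deptree import RstDepTree
from educe.rst_dt.dep2con import deptree_to_simple_rst_tree

# made-up corpus location
reader = Reader('path/to/rst_corpus/TRAINING')
for doc_name, ctree in reader.slurp().items():
    dtree = RstDepTree.from_rst_tree(ctree, nary_enc='chain')
    stree = deptree_to_simple_rst_tree(dtree)  # back to a SimpleRSTTree
    break
```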
""" - rel = dtree.labels[subtree] - nuc = dtree.nucs[subtree] + # create tree leaf for src + edu_src = dtree.edus[subtree] + src = SimpleRSTTree( + Node("leaf", (edu_src.num, edu_src.num), edu_src.text_span(), + "leaf"), + [edu_src]) - src = mk_leaf(dtree.edus[subtree]) # descend into each child, but note that we are folding # rather than mapping, ie. we threading along a nested # RST tree as go from sibling to sibling ranked_targets = dtree.deps(subtree) for tgt in ranked_targets: src = walk(src, tgt) - # ancestor is None in the case of the root node - return connect_trees(ancestor, src, rel, nuc) if ancestor else src + + if not ancestor: + # ancestor is None in the case of the root node + return src + + # connect ancestor with src + n_anc = treenode(ancestor) + n_src = treenode(src) + rel = dtree.labels[subtree] + nuc = dtree.nucs[subtree] + # + if n_anc.span.overlaps(n_src.span): + raise RstDtException("Span %s overlaps with %s " % + (n_anc.span, n_src.span)) + else: + if n_anc.span <= n_src.span: + left = ancestor + right = src + nuc_kids = [NUC_N, nuc] + else: + left = src + right = ancestor + nuc_kids = [nuc, NUC_N] + # nuc in SimpleRSTTree is the concatenation of the initial + # letter of each kid's nuclearity for the relation, + # eg. {NS, SN, NN} + nuc = ''.join(x[0] for x in nuc_kids) + # compute EDU span of the parent node from the kids' + l_edu_span = treenode(left).edu_span + r_edu_span = treenode(right).edu_span + edu_span = (min(l_edu_span[0], r_edu_span[0]), + max(l_edu_span[1], r_edu_span[1])) + txt_span = n_anc.span.merge(n_src.span) + res = SimpleRSTTree( + Node(nuc, edu_span, txt_span, rel), + [left, right]) + return res roots = dtree.real_roots_idx() if not allow_forest and len(roots) > 1: @@ -705,8 +691,7 @@ def walk(ancestor, subtree): srtrees = [] for real_root in roots: - rparts = walk(None, real_root) - srtree = parts_to_tree(NUC_R, rparts) + srtree = walk(None, real_root) srtrees.append(srtree) # for the most common case, return the tree @@ -718,16 +703,6 @@ def walk(ancestor, subtree): return srtrees -# pylint: disable=R0903, W0232 -class TreeParts(namedtuple("TreeParts_", "edu edu_span span rel kids")): - """ - Partially built RST tree when converting from dependency tree - Kids here is nuclearity-annotated children - """ - pass -# pylint: enable=R0903, W0232 - - def deptree_to_rst_tree(dtree): """Create an RSTTree from an RstDepTree. 
diff --git a/educe/rst_dt/deptree.py b/educe/rst_dt/deptree.py index ab1347e..214fe85 100644 --- a/educe/rst_dt/deptree.py +++ b/educe/rst_dt/deptree.py @@ -352,7 +352,7 @@ def walk(tree): rel = treenode(tree).rel left = tree[0] right = tree[1] - nscode = "".join(treenode(kid).nuclearity[0] for kid in tree) + nscode = treenode(tree).nuclearity lhead = walk(left) rhead = walk(right) diff --git a/educe/rst_dt/parse.py b/educe/rst_dt/parse.py index a3428aa..2b2c3dc 100644 --- a/educe/rst_dt/parse.py +++ b/educe/rst_dt/parse.py @@ -240,8 +240,6 @@ def parse_lightweight_tree(tstr): examples """ _lw_type_re = re.compile(r'(?P[RSN])(:(?P.*)|$)') - _lw_nuc_map = dict((nuc[0], nuc) - for nuc in ["Root", "Nucleus", "Satellite"]) # pylint: disable=C0103 PosInfo = collections.namedtuple("PosInfo", "text edu") # pylint: enable=C0103 @@ -254,25 +252,28 @@ def walk(subtree, posinfo=PosInfo(text=0, edu=0)): if isinstance(subtree, Tree): start = copy.copy(posinfo) children = [] + nuc_kids = [] for kid in subtree: - tree, posinfo = walk(kid, posinfo) + tree, posinfo, nuc_kid = walk(kid, posinfo) children.append(tree) + nuc_kids.append(nuc_kid) + nuclearity = ''.join(x for x in nuc_kids) match = _lw_type_re.match(treenode(subtree)) if not match: raise RSTTreeException( "Missing nuclearity annotation in " + str(subtree)) - nuclearity = _lw_nuc_map[match.group("nuc")] + nuc = match.group("nuc") rel = match.group("rel") or "leaf" edu_span = (start.edu, posinfo.edu - 1) span = Span(start.text, posinfo.text) node = Node(nuclearity, edu_span, span, rel) - return SimpleRSTTree(node, children), posinfo + return SimpleRSTTree(node, children), posinfo, nuc else: text = subtree start = posinfo.text end = start + len(text) posinfo2 = PosInfo(text=end, edu=posinfo.edu+1) - return EDU(posinfo.edu, Span(start, end), text), posinfo2 + return EDU(posinfo.edu, Span(start, end), text), posinfo2, "leaf" return walk(Tree.fromstring(tstr))[0] From c689eb5516cc932a9fee237c07186bc4b66145b7 Mon Sep 17 00:00:00 2001 From: moreymat Date: Fri, 10 Feb 2017 14:30:14 +0100 Subject: [PATCH 22/44] DOC fix a few docstrings --- educe/rst_dt/annotation.py | 23 ++++++++++++++++++----- educe/rst_dt/dep2con.py | 4 ++-- educe/rst_dt/deptree.py | 10 +++++++++- 3 files changed, 29 insertions(+), 8 deletions(-) diff --git a/educe/rst_dt/annotation.py b/educe/rst_dt/annotation.py index 9e4071b..e992cf6 100644 --- a/educe/rst_dt/annotation.py +++ b/educe/rst_dt/annotation.py @@ -115,6 +115,11 @@ def __init__(self, num, span, text, context=None, origin=None): def set_origin(self, origin): """ Update the origin of this annotation and any contained within + + Parameters + ---------- + origin : FileId + File identifier of the origin of this annotation. """ self.origin = origin @@ -297,8 +302,12 @@ def __init__(self, node, children, origin=None, verbose=False): # end WIP head def set_origin(self, origin): - """ - Update the origin of this annotation and any contained within + """Update the origin of this annotation and any contained within + + Parameters + ---------- + origin : FileId + File identifier of the origin of this annotation. """ self.origin = origin for child in self: @@ -465,9 +474,13 @@ def __init__(self, node, children, origin=None): # end WIP head def set_origin(self, origin): - """ - Recursively update the origin for this annotation, ie. - a little link to the document metadata for this annotation + """Recursively update the origin for this annotation, ie. + a little link to the document metadata for this annotation. 
+ + Parameters + ---------- + origin : FileId + File identifier of the origin of this annotation. """ self.origin = origin for child in self: diff --git a/educe/rst_dt/dep2con.py b/educe/rst_dt/dep2con.py index 2c6b721..e24eabd 100644 --- a/educe/rst_dt/dep2con.py +++ b/educe/rst_dt/dep2con.py @@ -645,9 +645,9 @@ def walk(ancestor, subtree): ranked_targets = dtree.deps(subtree) for tgt in ranked_targets: src = walk(src, tgt) - if not ancestor: - # ancestor is None in the case of the root node + # first call: ancestor is None, subtree is the index of the + # (presumably unique) real root return src # connect ancestor with src diff --git a/educe/rst_dt/deptree.py b/educe/rst_dt/deptree.py index 214fe85..4a43ac1 100644 --- a/educe/rst_dt/deptree.py +++ b/educe/rst_dt/deptree.py @@ -83,8 +83,10 @@ class RstDepTree(object): ---------- edus : list of EDU List of the EDUs of this document. + origin : Document?, optional TODO + nary_enc : one of {'chain', 'tree'}, optional Type of encoding used for n-ary relations: 'chain' or 'tree'. This determines for example how fragmented EDUs are resolved. @@ -298,7 +300,13 @@ def real_roots_idx(self): return self.deps(_ROOT_HEAD) def set_origin(self, origin): - """Update the origin of this annotation""" + """Update the origin of this annotation. + + Parameters + ---------- + origin : FileId + File identifier of the origin of this annotation. + """ self.origin = origin def spans(self): From ca2f2ac5463cd1529ecf2ea491c85a376ff5b603 Mon Sep 17 00:00:00 2001 From: moreymat Date: Fri, 10 Feb 2017 15:27:59 +0100 Subject: [PATCH 23/44] DOC warn about the bug-prone API of deptree_to_simple_rst_tree --- educe/rst_dt/dep2con.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/educe/rst_dt/dep2con.py b/educe/rst_dt/dep2con.py index e24eabd..0dab3ed 100644 --- a/educe/rst_dt/dep2con.py +++ b/educe/rst_dt/dep2con.py @@ -592,6 +592,15 @@ def deptree_to_simple_rst_tree(dtree, allow_forest=False): (and so on, until all we have left is a single RST tree). + + TODO + ---- + * [ ] fix the signature of this function: change name or arguments + or return type, because the current implementation returns + either a SimpleRSTTree if allow_forest=False, or a list of + SimpleRSTTree if allow_forest=True. This is a likely source of + errors because SimpleRSTTrees are list-like, ie. tree[i] + returns the i-th child of a tree node... 
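Until that signature is fixed, a small defensive wrapper is one way to sidestep the trap; this is only a sketch, assuming `SimpleRSTTree` is the class defined in `educe.rst_dt.annotation`:

```python
from educe.rst_dt.annotation import SimpleRSTTree
from educe.rst_dt.dep2con import deptree_to_simple_rst_tree

def to_simple_ctrees(dtree):
    """Always return a list of SimpleRSTTree, one per real root."""
    res = deptree_to_simple_rst_tree(dtree, allow_forest=True)
    # SimpleRSTTree is itself list-like, so test the class explicitly
    # rather than isinstance(res, list)
    return [res] if isinstance(res, SimpleRSTTree) else list(res)
```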
""" def walk(ancestor, subtree): """ From f808be11fdd798822dc5ca06a5b0bbf605347a60 Mon Sep 17 00:00:00 2001 From: moreymat Date: Tue, 14 Feb 2017 16:51:19 +0100 Subject: [PATCH 24/44] FIX dump in feature extraction for STAC --- educe/learning/keygroup_vectorizer.py | 19 +++++++++---------- educe/stac/learning/cmd/extract.py | 17 ++++++++++++++--- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/educe/learning/keygroup_vectorizer.py b/educe/learning/keygroup_vectorizer.py index ae6735d..c008472 100644 --- a/educe/learning/keygroup_vectorizer.py +++ b/educe/learning/keygroup_vectorizer.py @@ -58,7 +58,6 @@ def _count_vocab(self, vectors, fixed_vocab=False): doc_ptr = [] doc_ptr.append(0) - print('fit vocab') # DEBUG for vecs in vectors: for vec in vecs: for feature, featval in vec.one_hot_values_gen(): @@ -70,7 +69,6 @@ def _count_vocab(self, vectors, fixed_vocab=False): continue row_ptr.append(len(feature_acc)) doc_ptr.append(row_ptr[-1]) - print('vocab done') # DEBUG if not fixed_vocab: vocabulary = dict(vocabulary) @@ -86,16 +84,17 @@ def _count_vocab(self, vectors, fixed_vocab=False): # start a new doc matrix X.append([]) doc_nxt += 1 - print('doc ', str(i)) # DEBUG + # print('doc ', str(i)) # DEBUG x = feature_acc[current_row:next_row] X[-1].append(x) - # DEBUG - n_edus = [len(y) for y in X] - print(len(vocabulary), sys.getsizeof(vocabulary)) - print(len(X), sum(len(y) for y in X), sys.getsizeof(X)) - print(sum(nb_edus * (nb_edus - 1) for nb_edus in n_edus)) - raise ValueError('woopti') - # end DEBUG + + if False: # DEBUG + n_edus = [len(y) for y in X] + print(len(vocabulary), sys.getsizeof(vocabulary)) + print(len(X), sum(len(y) for y in X), sys.getsizeof(X)) + print(sum(nb_edus * (nb_edus - 1) for nb_edus in n_edus)) + raise ValueError('woopti') + return vocabulary, X def fit_transform(self, vectors): diff --git a/educe/stac/learning/cmd/extract.py b/educe/stac/learning/cmd/extract.py index 2baa2f2..eaa15d1 100755 --- a/educe/stac/learning/cmd/extract.py +++ b/educe/stac/learning/cmd/extract.py @@ -74,6 +74,18 @@ def config_argparser(parser): choices=['head', 'broadcast', 'custom'], default='head', help='CDUs stripping method (if going into CDUs)') + # WIP 2017-02-02 toggle between corpus- and doc-centric feature + # extraction + parser.add_argument('--file_split', + choices=['dialogue', 'corpus'], + default='corpus', + help=("Level of granularity for each set of " + "files: " + "'dialogue' produces one set of files per " + "dialogue ; " + "'corpus' one set of files per corpus " + "split (e.g. 
'train', 'test')")) + # end WIP toggle between corpus- and doc-centric feature extraction parser.set_defaults(func=main) # --------------------------------------------------------------------- @@ -189,15 +201,14 @@ def main_pairs(args): out_file = fp.join(outdir_corpus, '{dia_id}.relations.sparse'.format( dia_id=dia_id)) - dump_all(X, y, out_file, dia, instance_generator) + dump_all(X, y, out_file, [dia], instance_generator) elif args.file_split == 'corpus': # one file per corpus (in fact corpus split) # these paths should go away once we switch to a proper dumper out_file = fp.join(outdir, corpus_name + '.relations.sparse') - dump_all(X_gen, y_gen, out_file, labtor.labelset_, dialogues, - instance_generator) + dump_all(X_gen, y_gen, out_file, dialogues, instance_generator) # end WIP # dump vocabulary From 6a0cb77986d4a25430783ea6419ee50b20c44cb1 Mon Sep 17 00:00:00 2001 From: moreymat Date: Tue, 14 Feb 2017 20:35:55 +0100 Subject: [PATCH 25/44] FIX load and dump labels for STAC --- educe/stac/learning/cmd/extract.py | 77 +++++++++++++++++++-------- educe/stac/learning/doc_vectorizer.py | 4 +- 2 files changed, 58 insertions(+), 23 deletions(-) diff --git a/educe/stac/learning/cmd/extract.py b/educe/stac/learning/cmd/extract.py index eaa15d1..5ce4b11 100755 --- a/educe/stac/learning/cmd/extract.py +++ b/educe/stac/learning/cmd/extract.py @@ -9,29 +9,28 @@ """ from __future__ import print_function -from os import path as fp + import itertools import os +from os import path as fp import sys +import educe.corpus +import educe.glozz +from educe.learning.edu_input_format import (dump_all, dump_labels, + dump_svmlight_file, + dump_edu_input_file, + load_labels) from educe.learning.keygroup_vectorizer import (KeyGroupVectorizer) +from educe.learning.vocabulary_format import (dump_vocabulary, + load_vocabulary) +import educe.stac from educe.stac.annotation import (DIALOGUE_ACTS, SUBORDINATING_RELATIONS, COORDINATING_RELATIONS) from educe.stac.learning.doc_vectorizer import ( DialogueActVectorizer, LabelVectorizer, mk_high_level_dialogues, extract_pair_features, extract_single_features, read_corpus_inputs) - - -import educe.corpus -from educe.learning.edu_input_format import (dump_all, - labels_comment, - dump_svmlight_file, - dump_edu_input_file) -from educe.learning.vocabulary_format import (dump_vocabulary, - load_vocabulary) -import educe.glozz -import educe.stac import educe.util @@ -68,6 +67,13 @@ def config_argparser(parser): parser.add_argument('--vocabulary', metavar='FILE', help='Vocabulary file (for --parsing mode)') + # 2017-02-14 file for labels, for discourse relations + parser.add_argument('--labels', + metavar='FILE', + help='Read label set from given file ' + '(important when extracting test data for ' + 'discourse relations)') + # end labels parser.add_argument('--ignore-cdus', action='store_true', help='Avoid going into CDUs') parser.add_argument('--strip-mode', @@ -107,7 +113,15 @@ def main_single(args): # TODO? just transform() if args.parsing or args.vocabulary? 
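The label files passed via `--labels` just below are assumed to be tab-separated `index<TAB>label` lines with `__UNK__` reserved at index 0, as written by `dump_labels` and read back by `load_labels` in `educe.learning.edu_input_format`; a minimal round-trip sketch with toy labels:

```python
from educe.learning.edu_input_format import dump_labels, load_labels

# toy labelset; index 0 must be the reserved '__UNK__' entry
dump_labels({'__UNK__': 0, 'elaboration': 1, 'attribution': 2}, 'toy.labels')
labelset = load_labels('toy.labels')
assert labelset == {'__UNK__': 0, 'elaboration': 1, 'attribution': 2}
```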
X_gen = vzer.fit_transform(feats) # pylint: enable=invalid-name - labtor = DialogueActVectorizer(instance_generator, DIALOGUE_ACTS) + if args.labels is not None: + labelset = load_labels(args.labels) + labels = [lbl for lbl, idx + in sorted(labelset.items(), key=lambda k, v: v)] + # DialogueActVectorizer.__init__ reserves (0, UNK) + labels = labels[1:] + else: + labels = sorted(DIALOGUE_ACTS) + labtor = DialogueActVectorizer(instance_generator, labels) y_gen = labtor.transform(dialogues) # create directory structure: {output}/ @@ -117,9 +131,6 @@ def main_single(args): corpus_name = fp.basename(args.corpus) - # list dialogue acts - comment = labels_comment(labtor.labelset_) - # dump: EDUs, pairings, vectorized pairings with label # WIP switch to a document (here dialogue) centric generation of data # 1. create a folder for the corpus: {output}/{corpus}/ @@ -129,6 +140,7 @@ def main_single(args): # 2. dump edu_input and features files if args.file_split == 'dialogue': # one file per dialogue + # pylint: disable=invalid-name for dia, X, y in itertools.izip(dialogues, X_gen, y_gen): dia_id = dia.grouping print('dump dialogue', dia_id) @@ -138,7 +150,8 @@ def main_single(args): dia_id=dia_id)) edu_input_file = '{feat_file}.edu_input'.format(feat_file=feat_file) dump_edu_input_file(dia, edu_input_file) - dump_svmlight_file(X, y, feat_file, comment=comment) + dump_svmlight_file(X, y, feat_file) + # pylint: enable=invalid-name elif args.file_split == 'corpus': # one file per corpus (in fact, corpus split) # these paths should go away once we switch to a proper dumper @@ -146,7 +159,7 @@ def main_single(args): corpus_name + '.dialogue-acts.sparse') edu_input_file = out_file + '.edu_input' dump_edu_input_file(dialogues, edu_input_file) - dump_svmlight_file(X_gen, y_gen, out_file, comment=comment) + dump_svmlight_file(X_gen, y_gen, out_file) # end WIP # dump vocabulary @@ -157,6 +170,11 @@ def main_single(args): '{corpus_name}.dialogue-acts.sparse.vocab'.format( corpus_name=corpus_name)) dump_vocabulary(vzer.vocabulary_, vocab_file) + # dump labels + labels_file = fp.join(outdir, + '{corpus_name}.dialogue-acts.labels'.format( + corpus_name=corpus_name)) + dump_labels(labtor.labelset_, labels_file) def main_pairs(args): @@ -166,8 +184,17 @@ def main_pairs(args): dialogues = list(mk_high_level_dialogues(inputs, stage)) instance_generator = lambda x: x.edu_pairs() - labels = frozenset(SUBORDINATING_RELATIONS + - COORDINATING_RELATIONS) + if args.labels is not None: + labelset = load_labels(args.labels) + labels = [lbl for lbl, idx + in sorted(labelset.items(), key=lambda k, v: v)] + # LabelVectorizer.__init__ automatically reserves the first three + # indices: (0, UNK), (1, ROOT), (2, UNRELATED) + labels = labels[3:] + else: + labels = frozenset(SUBORDINATING_RELATIONS + + COORDINATING_RELATIONS) + labels = sorted(labels) # pylint: disable=invalid-name # X, y follow the naming convention in sklearn @@ -190,11 +217,13 @@ def main_pairs(args): corpus_name = fp.basename(args.corpus) - # WIP switch to a document (here dialogue) centric generation of data outdir_corpus = fp.join(outdir, corpus_name) if not fp.exists(outdir_corpus): os.makedirs(outdir_corpus) + + # WIP switch to a document (here dialogue) centric generation of data if args.file_split == 'dialogue': + # pylint: disable=invalid-name for dia, X, y in itertools.izip(dialogues, X_gen, y_gen): dia_id = dia.grouping # these paths should go away once we switch to a proper dumper @@ -202,6 +231,7 @@ def main_pairs(args): 
'{dia_id}.relations.sparse'.format( dia_id=dia_id)) dump_all(X, y, out_file, [dia], instance_generator) + # pylint: enable=invalid-name elif args.file_split == 'corpus': # one file per corpus (in fact corpus split) # these paths should go away once we switch to a proper dumper @@ -216,6 +246,11 @@ def main_pairs(args): '{corpus_name}.relations.sparse.vocab'.format( corpus_name=corpus_name)) dump_vocabulary(vzer.vocabulary_, vocab_file) + # dump labels + labels_file = fp.join(outdir, + '{corpus_name}.relations.labels'.format( + corpus_name=corpus_name)) + dump_labels(labtor.labelset_, labels_file) def main(args): diff --git a/educe/stac/learning/doc_vectorizer.py b/educe/stac/learning/doc_vectorizer.py index d275e7c..02e04a3 100644 --- a/educe/stac/learning/doc_vectorizer.py +++ b/educe/stac/learning/doc_vectorizer.py @@ -39,7 +39,7 @@ def __init__(self, instance_generator, labels): """ instance_generator to enumerate the instances from a doc - :type labels: set(string) + :type labels: list(string) """ self.instance_generator = instance_generator self.labelset_ = {l: i for i, l in enumerate(labels, start=1)} @@ -77,7 +77,7 @@ class LabelVectorizer(object): instance_generator : fun(doc) -> :obj:`list` of (EDU, EDU) Function to enumerate the instances from a doc. - labels : :obj:`set` of str + labels : :obj:`list` of str Labelset zero : boolean, defaults to False From e0bf3dc195abc2438266a8a4b86fa1d2b56c9a92 Mon Sep 17 00:00:00 2001 From: moreymat Date: Fri, 17 Feb 2017 16:24:11 +0100 Subject: [PATCH 26/44] FIX paths to data files under data/{corpus_name}/ --- educe/stac/learning/cmd/extract.py | 32 +++++++++++++++++++----------- 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/educe/stac/learning/cmd/extract.py b/educe/stac/learning/cmd/extract.py index 5ce4b11..e901070 100755 --- a/educe/stac/learning/cmd/extract.py +++ b/educe/stac/learning/cmd/extract.py @@ -145,19 +145,24 @@ def main_single(args): dia_id = dia.grouping print('dump dialogue', dia_id) # these paths should go away once we switch to a proper dumper - feat_file = fp.join(outdir_corpus, - '{dia_id}.dialogue-acts.sparse'.format( - dia_id=dia_id)) - edu_input_file = '{feat_file}.edu_input'.format(feat_file=feat_file) + feat_file = fp.join( + outdir_corpus, + '{dia_id}.dialogue-acts.sparse'.format( + dia_id=dia_id)) + edu_input_file = '{feat_file}.edu_input'.format( + feat_file=feat_file) dump_edu_input_file(dia, edu_input_file) dump_svmlight_file(X, y, feat_file) # pylint: enable=invalid-name elif args.file_split == 'corpus': # one file per corpus (in fact, corpus split) # these paths should go away once we switch to a proper dumper - out_file = fp.join(outdir, - corpus_name + '.dialogue-acts.sparse') - edu_input_file = out_file + '.edu_input' + out_file = fp.join( + outdir_corpus, + '{corpus_name}.dialogue-acts.sparse'.format( + corpus_name=corpus_name)) + edu_input_file = '{out_file}.edu_input'.format( + out_file=out_file) dump_edu_input_file(dialogues, edu_input_file) dump_svmlight_file(X_gen, y_gen, out_file) # end WIP @@ -227,16 +232,19 @@ def main_pairs(args): for dia, X, y in itertools.izip(dialogues, X_gen, y_gen): dia_id = dia.grouping # these paths should go away once we switch to a proper dumper - out_file = fp.join(outdir_corpus, - '{dia_id}.relations.sparse'.format( - dia_id=dia_id)) + out_file = fp.join( + outdir_corpus, + '{dia_id}.relations.sparse'.format( + dia_id=dia_id)) dump_all(X, y, out_file, [dia], instance_generator) # pylint: enable=invalid-name elif args.file_split == 'corpus': # one file 
per corpus (in fact corpus split) # these paths should go away once we switch to a proper dumper - out_file = fp.join(outdir, - corpus_name + '.relations.sparse') + out_file = fp.join( + outdir_corpus, + '{corpus_name}.relations.sparse'.format( + corpus_name=corpus_name)) dump_all(X_gen, y_gen, out_file, dialogues, instance_generator) # end WIP From b920785d6cb773c737b90581c6ff270a5d815a55 Mon Sep 17 00:00:00 2001 From: moreymat Date: Tue, 11 Apr 2017 10:29:03 +0200 Subject: [PATCH 27/44] FIX rename metrics to S, N, R, F --- educe/metrics/parseval.py | 2 +- educe/rst_dt/metrics/rst_parseval.py | 20 ++++++++++---------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/educe/metrics/parseval.py b/educe/metrics/parseval.py index 74057b9..0314653 100644 --- a/educe/metrics/parseval.py +++ b/educe/metrics/parseval.py @@ -164,7 +164,7 @@ def parseval_report(ctree_true, ctree_pred, exclude_root=False, TODO metric_types: list of strings, optional Metrics that need to be included in the report ; if None is - given, defaults to ['S', 'S+N', 'S+R', 'S+N+R']. + given, defaults to ['S', 'N', 'R', 'F']. digits: int, defaults to 4 Number of decimals to print. print_support_pred: boolean, defaults to True diff --git a/educe/rst_dt/metrics/rst_parseval.py b/educe/rst_dt/metrics/rst_parseval.py index af8ba1c..65547f2 100644 --- a/educe/rst_dt/metrics/rst_parseval.py +++ b/educe/rst_dt/metrics/rst_parseval.py @@ -16,14 +16,14 @@ # label extraction functions LBL_FNS = [ ('S', lambda span: 1), - ('S+N', lambda span: span[1]), - ('S+R', lambda span: span[2]), - ('S+N+R', lambda span: '{}-{}'.format(span[2], span[1])), + ('N', lambda span: span[1]), + ('R', lambda span: span[2]), + ('F', lambda span: '{}-{}'.format(span[2], span[1])), # WIP 2016-11-10 add head to evals ('S+H', lambda span: span[3]), - ('S+N+H', lambda span: '{}-{}'.format(span[1], span[3])), - ('S+R+H', lambda span: '{}-{}'.format(span[2], span[3])), - ('S+N+R+H', lambda span: '{}-{}'.format(span[2], span[1])), + ('N+H', lambda span: '{}-{}'.format(span[1], span[3])), + ('R+H', lambda span: '{}-{}'.format(span[2], span[3])), + ('F+H', lambda span: '{}-{}'.format(span[2], span[1])), # end WIP head ] @@ -115,7 +115,7 @@ def rst_parseval_report(ctree_true, ctree_pred, ctree_type='RST', metric_types : list of strings, optional Metrics that need to be included in the report ; if None is - given, defaults to ['S', 'S+N', 'S+R', 'S+N+R']. + given, defaults to ['S', 'N', 'R', 'F']. digits : int, defaults to 4 Number of decimals to print. @@ -156,7 +156,7 @@ def rst_parseval_report(ctree_true, ctree_pred, ctree_type='RST', # select metrics and the corresponding functions if metric_types is None: - # metric_types = ['S', 'S+N', 'S+R', 'S+N+R'] + # metric_types = ['S', 'N', 'R', 'F'] metric_types = [x[0] for x in LBL_FNS] if set(metric_types) - set(x[0] for x in LBL_FNS): raise ValueError('Unknown metric types in {}'.format(metric_types)) @@ -173,7 +173,7 @@ def rst_parseval_report(ctree_true, ctree_pred, ctree_type='RST', def rst_parseval_detailed_report(ctree_true, ctree_pred, ctree_type='RST', - subtree_filter=None, metric_type='S+R', + subtree_filter=None, metric_type='R', labels=None, sort_by_support=True, digits=4, per_doc=False): """Build a text report showing the PARSEVAL discourse metrics per label. @@ -203,7 +203,7 @@ def rst_parseval_detailed_report(ctree_true, ctree_pred, ctree_type='RST', subtree_filter: function, optional Function to filter all local trees. 
- metric_type : one of {'S+R', 'S+N+R'}, defaults to 'S+R' + metric_type : one of {'R', 'F'}, defaults to 'R' Metric that need to be included in the report. digits : int, defaults to 4 From eef2bbc808fa4318488530bb2fb31010b386510d Mon Sep 17 00:00:00 2001 From: moreymat Date: Tue, 11 Apr 2017 17:26:16 +0200 Subject: [PATCH 28/44] ENH parseval_compact_report --- educe/metrics/parseval.py | 93 ++++++++++++++++++++++++++ educe/rst_dt/metrics/rst_parseval.py | 99 ++++++++++++++++++++++++++++ 2 files changed, 192 insertions(+) diff --git a/educe/metrics/parseval.py b/educe/metrics/parseval.py index 0314653..c9aaec1 100644 --- a/educe/metrics/parseval.py +++ b/educe/metrics/parseval.py @@ -9,6 +9,7 @@ """ from __future__ import absolute_import, print_function +import warnings import numpy as np @@ -146,6 +147,98 @@ def parseval_scores(ctree_true, ctree_pred, subtree_filter=None, return p, r, f1, s_true, s_pred, labels +def parseval_compact_report(ctree_true, parser_preds, + exclude_root=False, subtree_filter=None, + lbl_fns=None, digits=4, + print_support=True, + per_doc=False, + add_trivial_spans=False): + """Build a text report showing the F1-scores of the PARSEVAL metrics + for a list of parsers. + + This is the simplest and most compact report we need to generate, it + corresponds to the comparative arrays of results from the literature. + Metrics are calculated globally (average='micro'), unless per_doc is + True (macro-averaging across documents). + + Parameters + ---------- + ctree_true: TODO + TODO + + parser_preds: list of (parser_name, ctree_pred) + Predicted c-trees for each parser. + + metric_types: list of strings, optional + Metrics that need to be included in the report ; if None is + given, defaults to ['S', 'N', 'R', 'F']. + + digits: int, defaults to 4 + Number of decimals to print. + + span_sel: TODO + TODO + + per_doc: boolean, defaults to False + If True, compute p, r, f for each doc separately then compute the + mean of each score over docs. This is *not* the correct + implementation, but it corresponds to that in DPLP. 
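For RST-DT evaluation the intended entry point is the wrapper added further below (`rst_parseval_compact_report`); a minimal sketch that scores the gold trees against themselves, just to show the shape of the report (corpus path made up):

```python
from __future__ import print_function

from educe.rst_dt.corpus import Reader
from educe.rst_dt.metrics.rst_parseval import rst_parseval_compact_report

# made-up corpus location
ctree_true = list(Reader('path/to/rst_corpus/TEST').slurp().values())
parser_preds = [('gold-as-pred', ctree_true)]
print(rst_parseval_compact_report(ctree_true, parser_preds,
                                  metric_types=['S', 'N', 'R', 'F'],
                                  digits=3))
```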
+ """ + if lbl_fns is None: + # we require a labelled span to be a pair (span, lbl) + # where span and lbl can be anything, for example + # * span = (span_beg, span_end) + # * lbl = (nuc, rel) + lbl_fns = [('Labelled Span', lambda span_lbl: span_lbl[1])] + + metric_types = [k for k, v in lbl_fns] + + # prepare scaffold for report + width = max(len(parser_name) for parser_name, _ in parser_preds) + + headers = [x for x in metric_types] + if print_support: + headers += ["support"] + fmt = '%% %ds' % width # first col: parser name + fmt += ' ' + fmt += ' '.join(['% 9s' for _ in headers]) + fmt += '\n' + headers = [""] + headers + report = fmt % tuple(headers) + report += '\n' + + for parser_name, ctree_pred in parser_preds: + values = [parser_name] + # compute scores + metric_scores = dict() + for metric_type, lbl_fn in lbl_fns: + p, r, f1, s_true, s_pred, labels = parseval_scores( + ctree_true, ctree_pred, subtree_filter=subtree_filter, + exclude_root=exclude_root, lbl_fn=lbl_fn, labels=None, + average='micro', per_doc=per_doc, + add_trivial_spans=add_trivial_spans) + metric_scores[metric_type] = (p, r, f1, s_true, s_pred) + + # fill report + support = 0 + for metric_type in metric_types: + (p, r, f1, s_true, s_pred) = metric_scores[metric_type] + values += ["{0:0.{1}f}".format(f1, digits)] + # (warning) support in _true and _pred should be the same ; + if s_true != s_pred: + warnings.warn("s_pred != s_true") + # store support in _true, for optional display below + if support == 0: + support = s_true + # append support + if print_support: + values += ["{0}".format(support)] # support_true + + report += fmt % tuple(values) + + return report + + def parseval_report(ctree_true, ctree_pred, exclude_root=False, subtree_filter=None, lbl_fns=None, digits=4, print_support_pred=True, per_doc=False, diff --git a/educe/rst_dt/metrics/rst_parseval.py b/educe/rst_dt/metrics/rst_parseval.py index 65547f2..bb74159 100644 --- a/educe/rst_dt/metrics/rst_parseval.py +++ b/educe/rst_dt/metrics/rst_parseval.py @@ -10,6 +10,7 @@ from __future__ import absolute_import, print_function from educe.metrics.parseval import (parseval_scores, parseval_report, + parseval_compact_report, parseval_detailed_report) @@ -78,6 +79,104 @@ def rst_parseval_scores(ctree_true, ctree_pred, lbl_fn, subtree_filter=None, labels=labels, average=average) +def rst_parseval_compact_report(ctree_true, parser_preds, + ctree_type='RST', subtree_filter=None, + metric_types=None, digits=4, + print_support=True, + per_doc=False, + add_trivial_spans=False, + stringent=False): + """Build a compact text report showing the f1-scores of the PARSEVAL + discourse metrics. + + This is the simplest report we need to generate, it corresponds + to the arrays of results from the literature. + Metrics are calculated globally (average='micro'). + + Parameters + ---------- + ctree_true: TODO + TODO + + parser_preds: List of (parser_name, List of ctree_pred) + List of predictions for each parser. + + ctree_type : one of {'RST', 'SimpleRST'}, defaults to 'RST' + Type of ctrees considered in the evaluation procedure. + 'RST' is the standard type of ctrees used in the RST corpus, + it triggers the exclusion of the root node from the evaluation + but leaves are kept. + 'SimpleRST' is a binarized variant of RST trees where each + internal node corresponds to an attachment decision ; in other + words, it is a binary ctree where the nuclearity and relation label + are moved one node up compared to the standard RST trees. 
This + triggers the exclusion of leaves from the eval, but the root node + is kept. + + subtree_filter: function, optional + Function to filter all local trees. + + metric_types : list of strings, optional + Metrics that need to be included in the report ; if None is + given, defaults to ['S', 'N', 'R', 'F']. + + digits : int, defaults to 4 + Number of decimals to print. + + print_support : boolean, defaults to True + If True, the true support, i.e. the number of reference spans, + is also displayed. This is useful to control whether the + reference ctrees have been binarized. + + per_doc : boolean, defaults to False + If True, compute p, r, f for each doc separately then compute the + mean of each score over docs. This is *not* the correct + implementation, but it corresponds to that in DPLP. + + add_trivial_spans : boolean, defaults to False + If True, trivial spans 0-0, 0-n, 1-n are added ; this is meant to + replicate the evaluation procedure of Li et al.'s dependency RST + parser. + + stringent : boolean, defaults to False + TODO + """ + # filter root or leaves, depending on the type of ctree + if ctree_type not in ['RST', 'SimpleRST']: + raise ValueError("ctree_type should be one of {'RST', 'SimpleRST'}") + if ctree_type == 'RST': + # standard RST ctree: exclude root + exclude_root = True + subtree_filter = subtree_filter + elif ctree_type == 'SimpleRST': + # SimpleRST variant: keep root, exclude leaves + exclude_root = False # TODO try True first, should get same as before + not_leaf = lambda t: t.height() > 2 # TODO unit test! + if subtree_filter is None: + subtree_filter = not_leaf + else: + subtree_filter = lambda t: not_leaf(t) and subtree_filter(t) + + # select metrics and the corresponding functions + if metric_types is None: + # metric_types = ['S', 'N', 'R', 'F'] + metric_types = [x[0] for x in LBL_FNS] + if set(metric_types) - set(x[0] for x in LBL_FNS): + raise ValueError('Unknown metric types in {}'.format(metric_types)) + metric2lbl_fn = dict(LBL_FNS) + lbl_fns = [(metric_type, metric2lbl_fn[metric_type]) + for metric_type in metric_types] + + return parseval_compact_report(ctree_true, parser_preds, + exclude_root=exclude_root, + subtree_filter=subtree_filter, + lbl_fns=lbl_fns, + digits=digits, + print_support=print_support, + per_doc=per_doc, + add_trivial_spans=add_trivial_spans) + + def rst_parseval_report(ctree_true, ctree_pred, ctree_type='RST', subtree_filter=None, metric_types=None, digits=4, print_support_pred=True, From b9d8a563a661ca15fc8cd86a4afcfcb329f266e7 Mon Sep 17 00:00:00 2001 From: moreymat Date: Wed, 12 Apr 2017 11:55:48 +0200 Subject: [PATCH 29/44] ENH ctree spans can be in chars --- educe/metrics/parseval.py | 18 +++++++++--- educe/rst_dt/annotation.py | 44 +++++++++++++++++++++------- educe/rst_dt/metrics/rst_parseval.py | 10 +++++-- 3 files changed, 56 insertions(+), 16 deletions(-) diff --git a/educe/metrics/parseval.py b/educe/metrics/parseval.py index c9aaec1..f1ff2be 100644 --- a/educe/metrics/parseval.py +++ b/educe/metrics/parseval.py @@ -19,6 +19,7 @@ def parseval_scores(ctree_true, ctree_pred, subtree_filter=None, exclude_root=False, lbl_fn=None, labels=None, + span_type='edus', average=None, per_doc=False, add_trivial_spans=False): """Compute PARSEVAL scores for ctree_pred wrt ctree_true. 
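The point of the new `span_type='chars'` option is that EDU indices are only comparable between two c-trees that share the same segmentation, whereas character offsets stay comparable when segmentations differ. A minimal sketch of the mapping, not part of the patch; the `EDU` record below is a hypothetical, simplified stand-in for educe's own EDU class, of which only the character offsets matter here:

    from collections import namedtuple

    # Hypothetical, simplified EDU record carrying character offsets.
    EDU = namedtuple('EDU', ['num', 'char_start', 'char_end'])

    def edu_span_to_char_span(edu_span, edus):
        # edu_span is a pair of 1-based EDU indices (beg, end), as in the
        # edu_span field of the tuples returned by get_spans(); edus is
        # the document's EDU list in text order.
        beg, end = edu_span
        return (edus[beg - 1].char_start, edus[end - 1].char_end)

    edus = [EDU(1, 0, 17), EDU(2, 18, 40), EDU(3, 41, 60)]
    print(edu_span_to_char_span((2, 3), edus))  # -> (18, 60)
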
@@ -80,10 +81,12 @@ def parseval_scores(ctree_true, ctree_pred, subtree_filter=None, # extract descriptions of spans from the true and pred trees spans_true = [ct.get_spans(subtree_filter=subtree_filter, - exclude_root=exclude_root) + exclude_root=exclude_root, + span_type=span_type) for ct in ctree_true] spans_pred = [ct.get_spans(subtree_filter=subtree_filter, - exclude_root=exclude_root) + exclude_root=exclude_root, + span_type=span_type) for ct in ctree_pred] # WIP replicate eval in Li et al.'s dep parser @@ -149,7 +152,9 @@ def parseval_scores(ctree_true, ctree_pred, subtree_filter=None, def parseval_compact_report(ctree_true, parser_preds, exclude_root=False, subtree_filter=None, - lbl_fns=None, digits=4, + lbl_fns=None, + span_type='edus', + digits=4, print_support=True, per_doc=False, add_trivial_spans=False): @@ -215,6 +220,7 @@ def parseval_compact_report(ctree_true, parser_preds, p, r, f1, s_true, s_pred, labels = parseval_scores( ctree_true, ctree_pred, subtree_filter=subtree_filter, exclude_root=exclude_root, lbl_fn=lbl_fn, labels=None, + span_type=span_type, average='micro', per_doc=per_doc, add_trivial_spans=add_trivial_spans) metric_scores[metric_type] = (p, r, f1, s_true, s_pred) @@ -240,7 +246,8 @@ def parseval_compact_report(ctree_true, parser_preds, def parseval_report(ctree_true, ctree_pred, exclude_root=False, - subtree_filter=None, lbl_fns=None, digits=4, + subtree_filter=None, lbl_fns=None, span_type='edus', + digits=4, print_support_pred=True, per_doc=False, add_trivial_spans=False): """Build a text report showing the PARSEVAL discourse metrics. @@ -298,6 +305,7 @@ def parseval_report(ctree_true, ctree_pred, exclude_root=False, p, r, f1, s_true, s_pred, labels = parseval_scores( ctree_true, ctree_pred, subtree_filter=subtree_filter, exclude_root=exclude_root, lbl_fn=lbl_fn, labels=None, + span_type=span_type, average='micro', per_doc=per_doc, add_trivial_spans=add_trivial_spans) metric_scores[metric_type] = (p, r, f1, s_true, s_pred) @@ -317,6 +325,7 @@ def parseval_report(ctree_true, ctree_pred, exclude_root=False, def parseval_detailed_report(ctree_true, ctree_pred, exclude_root=False, subtree_filter=None, lbl_fn=None, + span_type='edus', labels=None, sort_by_support=True, digits=4, per_doc=False): """Build a text report showing the PARSEVAL discourse metrics. @@ -362,6 +371,7 @@ class (or micro-averaged over all classes). p, r, f1, s_true, s_pred, labels = parseval_scores( ctree_true, ctree_pred, subtree_filter=subtree_filter, exclude_root=exclude_root, lbl_fn=lbl_fn, labels=labels, + span_type=span_type, average=None, per_doc=per_doc) # scaffold for report diff --git a/educe/rst_dt/annotation.py b/educe/rst_dt/annotation.py index e992cf6..82b5df4 100644 --- a/educe/rst_dt/annotation.py +++ b/educe/rst_dt/annotation.py @@ -397,7 +397,8 @@ def edu_span(self): """ return treenode(self).edu_span - def get_spans(self, subtree_filter=None, exclude_root=False): + def get_spans(self, subtree_filter=None, exclude_root=False, + span_type='edus'): """Get the spans of a constituency tree. Each span is described by a triplet (edu_span, nuclearity, @@ -405,15 +406,20 @@ def get_spans(self, subtree_filter=None, exclude_root=False): Parameters ---------- - subtree_filter: function, defaults to None + subtree_filter : function, defaults to None Function to filter all local trees. - exclude_root: boolean, defaults to False + exclude_root : boolean, defaults to False If True, exclude the span of the root node. 
This cannot be expressed with `subtree_filter` because the latter is limited to properties local to each subtree in isolation. Or maybe I just missed something. + span_type : one of {'edus', 'chars'} + Whether each span is expressed on EDU or character indices. + Character indices are useful to compare spans from trees + whose EDU segmentation differs. + Returns ------- spans: list of tuple((int, int), str, str) @@ -425,8 +431,14 @@ def get_spans(self, subtree_filter=None, exclude_root=False): if exclude_root: tnodes = tnodes[1:] # 2016-11-10 add a 4th element: head - spans = [(tn.edu_span, tn.nuclearity, tn.rel, tn.head) - for tn in tnodes] + # 2017-04-12 enable char spans + if span_type == 'chars': + spans = [((tn.span.char_start, tn.span.char_end), + tn.nuclearity, tn.rel, tn.head) + for tn in tnodes] + else: + spans = [(tn.edu_span, tn.nuclearity, tn.rel, tn.head) + for tn in tnodes] return spans def text(self): @@ -492,7 +504,8 @@ def text_span(self): def _members(self): return list(self) # children - def get_spans(self, subtree_filter=None, exclude_root=False): + def get_spans(self, subtree_filter=None, exclude_root=False, + span_type='edus'): """Get the spans of a constituency tree. Each span is described by a triplet (edu_span, nuclearity, @@ -500,15 +513,20 @@ def get_spans(self, subtree_filter=None, exclude_root=False): Parameters ---------- - subtree_filter: function, defaults to None + subtree_filter : function, defaults to None Function to filter all local trees. - exclude_root: boolean, defaults to False + exclude_root : boolean, defaults to False If True, exclude the span of the root node. This cannot be expressed with `subtree_filter` because the latter is limited to properties local to each subtree in isolation. Or maybe I just missed something. + span_type : one of {'edus', 'chars'} + Whether each span is expressed on EDU or character indices. + Character indices are useful to compare spans from trees + whose EDU segmentation differs. 
+ Returns ------- spans: list of tuple((int, int), str, str) @@ -520,8 +538,14 @@ def get_spans(self, subtree_filter=None, exclude_root=False): if exclude_root: tnodes = tnodes[1:] # 2016-11-10 add a 4th element: head - spans = [(tn.edu_span, tn.nuclearity, tn.rel, tn.head) - for tn in tnodes] + # 2017-04-12 enable char spans + if span_type == 'chars': + spans = [((tn.span.char_start, tn.span.char_end), + tn.nuclearity, tn.rel, tn.head) + for tn in tnodes] + else: + spans = [(tn.edu_span, tn.nuclearity, tn.rel, tn.head) + for tn in tnodes] return spans @classmethod diff --git a/educe/rst_dt/metrics/rst_parseval.py b/educe/rst_dt/metrics/rst_parseval.py index bb74159..c560c50 100644 --- a/educe/rst_dt/metrics/rst_parseval.py +++ b/educe/rst_dt/metrics/rst_parseval.py @@ -81,6 +81,7 @@ def rst_parseval_scores(ctree_true, ctree_pred, lbl_fn, subtree_filter=None, def rst_parseval_compact_report(ctree_true, parser_preds, ctree_type='RST', subtree_filter=None, + span_type='edus', metric_types=None, digits=4, print_support=True, per_doc=False, @@ -171,6 +172,7 @@ def rst_parseval_compact_report(ctree_true, parser_preds, exclude_root=exclude_root, subtree_filter=subtree_filter, lbl_fns=lbl_fns, + span_type=span_type, digits=digits, print_support=print_support, per_doc=per_doc, @@ -179,6 +181,7 @@ def rst_parseval_compact_report(ctree_true, parser_preds, def rst_parseval_report(ctree_true, ctree_pred, ctree_type='RST', subtree_filter=None, metric_types=None, + span_type='edus', digits=4, print_support_pred=True, per_doc=False, add_trivial_spans=False, @@ -265,6 +268,7 @@ def rst_parseval_report(ctree_true, ctree_pred, ctree_type='RST', return parseval_report(ctree_true, ctree_pred, exclude_root=exclude_root, subtree_filter=subtree_filter, lbl_fns=lbl_fns, + span_type=span_type, digits=digits, print_support_pred=print_support_pred, per_doc=per_doc, @@ -272,7 +276,8 @@ def rst_parseval_report(ctree_true, ctree_pred, ctree_type='RST', def rst_parseval_detailed_report(ctree_true, ctree_pred, ctree_type='RST', - subtree_filter=None, metric_type='R', + subtree_filter=None, span_type='edus', + metric_type='R', labels=None, sort_by_support=True, digits=4, per_doc=False): """Build a text report showing the PARSEVAL discourse metrics per label. 
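Each span description compared by these metrics is a 4-tuple (span, nuclearity, relation, head), where span is either an EDU span or, with `span_type='chars'`, a character span. A small sketch, not taken from the patch, of how label functions in the spirit of LBL_FNS pick out the part of the 4-tuple that a given metric scores; the real LBL_FNS are defined elsewhere in rst_parseval.py and may differ in detail:

    # One span description, as produced by get_spans(span_type='edus'):
    # (edu_span, nuclearity, relation, head)
    span = ((2, 5), 'Satellite', 'Elaboration', 3)

    # Hypothetical stand-ins for the ('S', 'N', 'R', 'F') label functions.
    lbl_fns_sketch = [
        ('S', lambda sp: 1),                # unlabelled span
        ('N', lambda sp: sp[1]),            # span + nuclearity
        ('R', lambda sp: sp[2]),            # span + relation
        ('F', lambda sp: (sp[1], sp[2])),   # span + nuclearity + relation
    ]
    for metric_type, lbl_fn in lbl_fns_sketch:
        # parseval_scores keeps (span, label) pairs for scoring
        print(metric_type, (span[0], lbl_fn(span)))
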
@@ -338,6 +343,7 @@ def rst_parseval_detailed_report(ctree_true, ctree_pred, ctree_type='RST', return parseval_detailed_report( ctree_true, ctree_pred, exclude_root=exclude_root, - subtree_filter=subtree_filter, lbl_fn=lbl_fn, + subtree_filter=subtree_filter, span_type=span_type, + lbl_fn=lbl_fn, labels=labels, sort_by_support=sort_by_support, digits=digits, per_doc=per_doc) From 62cab3ed479d74ce3d8e04cf7627fa073075f4c5 Mon Sep 17 00:00:00 2001 From: moreymat Date: Tue, 16 May 2017 13:38:35 +0200 Subject: [PATCH 30/44] ENH display percentages --- educe/metrics/parseval.py | 32 ++++++++++++++++++++++------ educe/rst_dt/metrics/rst_parseval.py | 9 +++++--- 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/educe/metrics/parseval.py b/educe/metrics/parseval.py index f1ff2be..bf38240 100644 --- a/educe/metrics/parseval.py +++ b/educe/metrics/parseval.py @@ -155,6 +155,7 @@ def parseval_compact_report(ctree_true, parser_preds, lbl_fns=None, span_type='edus', digits=4, + percent=False, print_support=True, per_doc=False, add_trivial_spans=False): @@ -212,6 +213,10 @@ def parseval_compact_report(ctree_true, parser_preds, report = fmt % tuple(headers) report += '\n' + # display percentages + if percent: + digits = digits - 2 + for parser_name, ctree_pred in parser_preds: values = [parser_name] # compute scores @@ -229,7 +234,8 @@ def parseval_compact_report(ctree_true, parser_preds, support = 0 for metric_type in metric_types: (p, r, f1, s_true, s_pred) = metric_scores[metric_type] - values += ["{0:0.{1}f}".format(f1, digits)] + values += ["{0:0.{1}f}".format(f1 * 100.0 if percent else f1, + digits)] # (warning) support in _true and _pred should be the same ; if s_true != s_pred: warnings.warn("s_pred != s_true") @@ -238,7 +244,7 @@ def parseval_compact_report(ctree_true, parser_preds, support = s_true # append support if print_support: - values += ["{0}".format(support)] # support_true + values += ["{0:.0f}".format(support)] # support_true report += fmt % tuple(values) @@ -247,7 +253,7 @@ def parseval_compact_report(ctree_true, parser_preds, def parseval_report(ctree_true, ctree_pred, exclude_root=False, subtree_filter=None, lbl_fns=None, span_type='edus', - digits=4, + digits=4, percent=False, print_support_pred=True, per_doc=False, add_trivial_spans=False): """Build a text report showing the PARSEVAL discourse metrics. @@ -299,6 +305,10 @@ def parseval_report(ctree_true, ctree_pred, exclude_root=False, report = fmt % tuple(headers) report += '\n' + # display percentages + if percent: + digits = digits - 2 + # compute scores metric_scores = dict() for metric_type, lbl_fn in lbl_fns: @@ -311,11 +321,14 @@ def parseval_report(ctree_true, ctree_pred, exclude_root=False, metric_scores[metric_type] = (p, r, f1, s_true, s_pred) # fill report + if percent: + digits = digits - 2 for metric_type in metric_types: (p, r, f1, s_true, s_pred) = metric_scores[metric_type] values = [metric_type] for v in (p, r, f1): - values += ["{0:0.{1}f}".format(v, digits)] + values += ["{0:0.{1}f}".format(v * 100.0 if percent else v, + digits)] values += ["{0}".format(s_true)] # support_true values += ["{0}".format(s_pred)] # support_pred report += fmt % tuple(values) @@ -327,7 +340,7 @@ def parseval_detailed_report(ctree_true, ctree_pred, exclude_root=False, subtree_filter=None, lbl_fn=None, span_type='edus', labels=None, sort_by_support=True, - digits=4, per_doc=False): + digits=4, percent=False, per_doc=False): """Build a text report showing the PARSEVAL discourse metrics. 
FIXME model after sklearn.metrics.classification.classification_report @@ -395,11 +408,15 @@ class (or micro-averaged over all classes). if sort_by_support: sorted_ilbls = sorted(sorted_ilbls, key=lambda x: s_true[x[0]], reverse=True) + # display percentages + if percent: + digits = digits - 2 # one line per label for i, label in sorted_ilbls: values = [label] for v in (p[i], r[i], f1[i]): - values += ["{0:0.{1}f}".format(v, digits)] + values += ["{0:0.{1}f}".format(v * 100.0 if percent else v, + digits)] values += ["{0}".format(s_true[i])] values += ["{0}".format(s_pred[i])] report += fmt % tuple(values) @@ -411,7 +428,8 @@ class (or micro-averaged over all classes). for v in (np.average(p, weights=s_true), np.average(r, weights=s_true), np.average(f1, weights=s_true)): - values += ["{0:0.{1}f}".format(v, digits)] + values += ["{0:0.{1}f}".format(v * 100.0 if percent else v, + digits)] values += ['{0}'.format(np.sum(s_true))] values += ['{0}'.format(np.sum(s_pred))] report += fmt % tuple(values) diff --git a/educe/rst_dt/metrics/rst_parseval.py b/educe/rst_dt/metrics/rst_parseval.py index c560c50..c73a292 100644 --- a/educe/rst_dt/metrics/rst_parseval.py +++ b/educe/rst_dt/metrics/rst_parseval.py @@ -83,6 +83,7 @@ def rst_parseval_compact_report(ctree_true, parser_preds, ctree_type='RST', subtree_filter=None, span_type='edus', metric_types=None, digits=4, + percent=False, print_support=True, per_doc=False, add_trivial_spans=False, @@ -174,6 +175,7 @@ def rst_parseval_compact_report(ctree_true, parser_preds, lbl_fns=lbl_fns, span_type=span_type, digits=digits, + percent=percent, print_support=print_support, per_doc=per_doc, add_trivial_spans=add_trivial_spans) @@ -182,7 +184,7 @@ def rst_parseval_compact_report(ctree_true, parser_preds, def rst_parseval_report(ctree_true, ctree_pred, ctree_type='RST', subtree_filter=None, metric_types=None, span_type='edus', - digits=4, print_support_pred=True, + digits=4, percent=False, print_support_pred=True, per_doc=False, add_trivial_spans=False, stringent=False): @@ -270,6 +272,7 @@ def rst_parseval_report(ctree_true, ctree_pred, ctree_type='RST', subtree_filter=subtree_filter, lbl_fns=lbl_fns, span_type=span_type, digits=digits, + percent=percent, print_support_pred=print_support_pred, per_doc=per_doc, add_trivial_spans=add_trivial_spans) @@ -279,7 +282,7 @@ def rst_parseval_detailed_report(ctree_true, ctree_pred, ctree_type='RST', subtree_filter=None, span_type='edus', metric_type='R', labels=None, sort_by_support=True, - digits=4, per_doc=False): + digits=4, percent=False, per_doc=False): """Build a text report showing the PARSEVAL discourse metrics per label. Metrics are calculated globally (average='micro'). 
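The `percent` option trades two decimal places for the two integer digits gained by scaling scores to percentages, so the printed column width stays roughly constant. A standalone sketch of that formatting rule as used throughout these reports; the helper name is ours, not the patch's:

    def format_score(score, digits=4, percent=False):
        # Mirrors the report formatting: with percent=True, scores are
        # multiplied by 100 and printed with digits - 2 decimal places.
        if percent:
            return "{0:0.{1}f}".format(score * 100.0, digits - 2)
        return "{0:0.{1}f}".format(score, digits)

    print(format_score(0.8325))                # 0.8325
    print(format_score(0.8325, percent=True))  # 83.25
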
@@ -346,4 +349,4 @@ def rst_parseval_detailed_report(ctree_true, ctree_pred, ctree_type='RST', subtree_filter=subtree_filter, span_type=span_type, lbl_fn=lbl_fn, labels=labels, sort_by_support=sort_by_support, - digits=digits, per_doc=per_doc) + digits=digits, percent=percent, per_doc=per_doc) From c7044d5312d17a700ec226bda9d209f1bba544c6 Mon Sep 17 00:00:00 2001 From: moreymat Date: Wed, 17 May 2017 11:41:45 +0200 Subject: [PATCH 31/44] ENH rst_dt.annotation._binarize() param branching --- educe/rst_dt/annotation.py | 36 +++++++++++++++++++++++++++++------- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/educe/rst_dt/annotation.py b/educe/rst_dt/annotation.py index 82b5df4..1891d5a 100644 --- a/educe/rst_dt/annotation.py +++ b/educe/rst_dt/annotation.py @@ -711,7 +711,7 @@ def is_binary(tree): return all(is_binary(x) for x in tree) -def _binarize(tree): +def _binarize(tree, branching='right_mixed'): """ Slightly rearrange an RST tree as a binary tree. The non-trivial cases here are @@ -730,14 +730,28 @@ def _binarize(tree): For example, given `X(List:N1, List:N2, List:N3)`, we would return `X(List:N1, List:N(List:N2, List:N3))` + + Parameters + ---------- + branching : str, one of {'left', 'right', 'right_mixed'} + Direction of the branching ; defaults to 'right_mixed', which + transforms n-ary multinuclear relations to a cascade of + right-branching binary trees, and SNS n-ary nodes into + left-branching binary trees. """ + branching_vals = ('left', 'right', 'right_mixed') + if branching not in branching_vals: + raise ValueError("branching must be one of {{}}".format( + branching_vals)) + if isinstance(tree, EDU): return tree elif len(tree) == 1 and not isinstance(tree[0], EDU): raise RSTTreeException("Ill-formed RST tree? Unary non-terminal: " + str(tree)) elif len(tree) <= 2: - return RSTTree(treenode(tree), [_binarize(x) for x in tree], + return RSTTree(treenode(tree), + [_binarize(x, branching=branching) for x in tree], origin=tree.origin) else: # convenient string representation of what the children look like @@ -755,13 +769,21 @@ def _binarize(tree): elif len(nuclei) > 1: # multi-nuclear chain if satellites: raise Exception("Multinuclear with satellites:\n%s" % tree) - kids = [_binarize(x) for x in tree] - left = kids[0] - right = _chain_to_binary(treenode(left).rel, kids[1:]) + kids = [_binarize(x, branching=branching) for x in tree] + if branching in ('right', 'right_mixed'): # right-branching + left = kids[0] + right = _chain_to_binary(treenode(left).rel, kids[1:]) + else: # left-branching + right = kids[-1] + left = _chain_to_binary(treenode(right).rel, kids[:-1]) return RSTTree(treenode(tree), [left, right], origin=tree.origin) elif nscode == 'SNS': - left = _chain_to_binary('span', tree[:2]) - right = _binarize(tree[2]) + if branching in ('left', 'right_mixed'): # left-branching + left = _chain_to_binary('span', tree[:2]) + right = _binarize(tree[2], branching=branching) + else: # right-branching + left = _binarize(tree[0], branching=branching) + right = _chain_to_binary('span', tree[1:]) return RSTTree(treenode(tree), [left, right], origin=tree.origin) else: raise RSTTreeException( From 695d40cf2be9c643ad13eb042c492287eb1b801a Mon Sep 17 00:00:00 2001 From: moreymat Date: Wed, 17 May 2017 16:21:08 +0200 Subject: [PATCH 32/44] ENH compact_report: parser_true --- educe/metrics/parseval.py | 15 ++++++++++++--- educe/rst_dt/metrics/rst_parseval.py | 4 ++-- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/educe/metrics/parseval.py 
b/educe/metrics/parseval.py index bf38240..222a16e 100644 --- a/educe/metrics/parseval.py +++ b/educe/metrics/parseval.py @@ -150,7 +150,7 @@ def parseval_scores(ctree_true, ctree_pred, subtree_filter=None, return p, r, f1, s_true, s_pred, labels -def parseval_compact_report(ctree_true, parser_preds, +def parseval_compact_report(parser_true, parser_preds, exclude_root=False, subtree_filter=None, lbl_fns=None, span_type='edus', @@ -169,8 +169,9 @@ def parseval_compact_report(ctree_true, parser_preds, Parameters ---------- - ctree_true: TODO - TODO + parser_true: str + Name of the parser used as reference ; it needs to be in the + keys of parser_preds. parser_preds: list of (parser_name, ctree_pred) Predicted c-trees for each parser. @@ -217,6 +218,14 @@ def parseval_compact_report(ctree_true, parser_preds, if percent: digits = digits - 2 + # find _true + for parser_name, ctree_pred in parser_preds: + if parser_name == parser_true: + ctree_true = ctree_pred + break + else: + raise ValueError('Unable to find reference c-trees') + for parser_name, ctree_pred in parser_preds: values = [parser_name] # compute scores diff --git a/educe/rst_dt/metrics/rst_parseval.py b/educe/rst_dt/metrics/rst_parseval.py index c73a292..4e50479 100644 --- a/educe/rst_dt/metrics/rst_parseval.py +++ b/educe/rst_dt/metrics/rst_parseval.py @@ -79,7 +79,7 @@ def rst_parseval_scores(ctree_true, ctree_pred, lbl_fn, subtree_filter=None, labels=labels, average=average) -def rst_parseval_compact_report(ctree_true, parser_preds, +def rst_parseval_compact_report(parser_true, parser_preds, ctree_type='RST', subtree_filter=None, span_type='edus', metric_types=None, digits=4, @@ -169,7 +169,7 @@ def rst_parseval_compact_report(ctree_true, parser_preds, lbl_fns = [(metric_type, metric2lbl_fn[metric_type]) for metric_type in metric_types] - return parseval_compact_report(ctree_true, parser_preds, + return parseval_compact_report(parser_true, parser_preds, exclude_root=exclude_root, subtree_filter=subtree_filter, lbl_fns=lbl_fns, From 605025234c68c69d5d1723fc4f4704b9530f55f2 Mon Sep 17 00:00:00 2001 From: moreymat Date: Wed, 17 May 2017 16:50:47 +0200 Subject: [PATCH 33/44] DOC minor fix --- educe/rst_dt/metrics/rst_parseval.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/educe/rst_dt/metrics/rst_parseval.py b/educe/rst_dt/metrics/rst_parseval.py index 4e50479..6ac7497 100644 --- a/educe/rst_dt/metrics/rst_parseval.py +++ b/educe/rst_dt/metrics/rst_parseval.py @@ -97,8 +97,8 @@ def rst_parseval_compact_report(parser_true, parser_preds, Parameters ---------- - ctree_true: TODO - TODO + parser_true: str + Name of the parser used as a ref. parser_preds: List of (parser_name, List of ctree_pred) List of predictions for each parser. 
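A usage sketch for the reworked compact report: `parser_true` now names the reference entry inside `parser_preds`, so the gold trees are passed as just another (name, ctrees) pair and every parser, gold included, is scored against it. The corpus variables below are assumptions, standing for lists of c-trees in the same document order:

    from educe.rst_dt.metrics.rst_parseval import rst_parseval_compact_report

    # ctrees_gold, ctrees_a, ctrees_b: hypothetical lists of RSTTree,
    # one per document, all in the same document order.
    parser_preds = [
        ('gold', ctrees_gold),
        ('parser_a', ctrees_a),
        ('parser_b', ctrees_b),
    ]
    # 'gold' is looked up in parser_preds and used as the reference.
    print(rst_parseval_compact_report('gold', parser_preds,
                                      metric_types=['S', 'N', 'R', 'F'],
                                      digits=4, percent=True))
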
From fe4c0c25d0bb322ff7a7d1b6fc0c8d203223139f Mon Sep 17 00:00:00 2001 From: moreymat Date: Thu, 18 May 2017 12:06:34 +0200 Subject: [PATCH 34/44] ENH parseval similarity matrix --- educe/metrics/parseval.py | 133 +++++++++++++++++++++++++++ educe/rst_dt/metrics/rst_parseval.py | 104 ++++++++++++++++++++- 2 files changed, 236 insertions(+), 1 deletion(-) diff --git a/educe/metrics/parseval.py b/educe/metrics/parseval.py index 222a16e..87c072c 100644 --- a/educe/metrics/parseval.py +++ b/educe/metrics/parseval.py @@ -260,6 +260,139 @@ def parseval_compact_report(parser_true, parser_preds, return report +def parseval_similarity(parser_preds, + exclude_root=False, subtree_filter=None, + lbl_fn=None, + span_type='edus', + digits=4, + percent=False, + print_support=True, + per_doc=False, + add_trivial_spans=False, + out_format='str'): + """Build a similarity matrix showing the F1-scores of a PARSEVAL metric + for a list of parsers. + + Metrics are calculated globally (average='micro'), unless per_doc is + True (macro-averaging across documents). + + Parameters + ---------- + parser_preds : list of (parser_name, ctree_pred) + Predicted c-trees for each parser. + + exclude_root : TODO + TODO + + subtree_filter : TODO + TODO + + lbl_fn : (str, function) + Metric on which the similarity is computed. + + span_type : TODO + TODO + + digits : int, defaults to 4 + Number of decimals to print. + + percent : TODO + TODO + + print_support : TODO + TODO + + per_doc : boolean, defaults to False + If True, compute p, r, f for each doc separately then compute the + mean of each score over docs. This is *not* the correct + implementation, but it corresponds to that in DPLP. + + add_trivial_spans : TODO + TODO + + out_format : str, one of {'str', 'latex'} + Output format. 
+ """ + if lbl_fn is None: + # we require a labelled span to be a pair (span, lbl) + # where span and lbl can be anything, for example + # * span = (span_beg, span_end) + # * lbl = (nuc, rel) + lbl_fn = ('Labelled Span', lambda span_lbl: span_lbl[1]) + + metric_type = lbl_fn[0] + + # prepare scaffold for report + width = max(len(parser_name) for parser_name, _ in parser_preds) + headers = [k[:7] for k, v in parser_preds] + if print_support: + headers += ["support"] + fmt = '%% %ds' % width # first col: parser name + if out_format == 'str': + fmt += ' ' + fmt += ' '.join(['% 9s' for _ in headers]) + elif out_format == 'latex': + fmt += ' &' + fmt += '&'.join(['% 9s' for _ in headers]) + fmt += '\\\\' # print "\\" + else: + raise ValueError("Unknown value for out_format: {}".format( + out_format)) + fmt += '\n' + headers = [""] + headers + + report = "" + if out_format == 'latex': + report += '\n'.join([ + '\\begin{table}[h]', + '\\begin{center}', + '\\begin{tabular}{' + 'l' * len(headers) +'}', + '\\toprule' + ]) + report += fmt % tuple(headers) + report += '\n' + if out_format == 'latex': + report += '\\midrule\n' + + # display percentages + if percent: + digits = digits - 2 + + for parser_true, ctree_true in parser_preds: + values = [parser_true] + for parser_name, ctree_pred in parser_preds: + # compute scores + p, r, f1, s_true, s_pred, labels = parseval_scores( + ctree_true, ctree_pred, subtree_filter=subtree_filter, + exclude_root=exclude_root, lbl_fn=lbl_fn[1], labels=None, + span_type=span_type, + average='micro', per_doc=per_doc, + add_trivial_spans=add_trivial_spans) + # fill report + values += ["{0:0.{1}f}".format(f1 * 100.0 if percent else f1, + digits)] + # store support in _true, for optional display below + support = s_true + + # append support + if print_support: + values += ["{0:.0f}".format(support)] # support_true + + report += fmt % tuple(values) + + if out_format == 'latex': + report += '\n'.join([ + '\\bottomrule', + '\\end{tabular}', + '\\end{center}', + '\\caption{\\label{ctree-sim} Similarity matrix on parsers predictions against non-binarized trees.}', + '\\end{table}' + ]) + report = report.replace('_', '\_') + + return report + + def parseval_report(ctree_true, ctree_pred, exclude_root=False, subtree_filter=None, lbl_fns=None, span_type='edus', digits=4, percent=False, diff --git a/educe/rst_dt/metrics/rst_parseval.py b/educe/rst_dt/metrics/rst_parseval.py index 6ac7497..5155ea3 100644 --- a/educe/rst_dt/metrics/rst_parseval.py +++ b/educe/rst_dt/metrics/rst_parseval.py @@ -11,7 +11,8 @@ from educe.metrics.parseval import (parseval_scores, parseval_report, parseval_compact_report, - parseval_detailed_report) + parseval_detailed_report, + parseval_similarity) # label extraction functions @@ -181,6 +182,107 @@ def rst_parseval_compact_report(parser_true, parser_preds, add_trivial_spans=add_trivial_spans) +def rst_parseval_similarity(parser_preds, + ctree_type='RST', + subtree_filter=None, + span_type='edus', + metric_type='S', digits=4, + percent=False, + print_support=True, + per_doc=False, + add_trivial_spans=False, + stringent=False, + out_format='str'): + """Build a similarity matrix showing the f1-scores of a PARSEVAL + discourse metric. + + This is the simplest report we need to generate, it corresponds + to the arrays of results from the literature. + Metrics are calculated globally (average='micro'). + + Parameters + ---------- + parser_preds : List of (parser_name, List of ctree_pred) + List of predictions for each parser. 
+ + ctree_type : one of {'RST', 'SimpleRST'}, defaults to 'RST' + Type of ctrees considered in the evaluation procedure. + 'RST' is the standard type of ctrees used in the RST corpus, + it triggers the exclusion of the root node from the evaluation + but leaves are kept. + 'SimpleRST' is a binarized variant of RST trees where each + internal node corresponds to an attachment decision ; in other + words, it is a binary ctree where the nuclearity and relation label + are moved one node up compared to the standard RST trees. This + triggers the exclusion of leaves from the eval, but the root node + is kept. + + subtree_filter : function, optional + Function to filter all local trees. + + metric_type : str + Metric on which similarity is judged, defaults to 'S'. + + digits : int, defaults to 4 + Number of decimals to print. + + print_support : boolean, defaults to True + If True, the true support, i.e. the number of reference spans, + is also displayed. This is useful to control whether the + reference ctrees have been binarized. + + per_doc : boolean, defaults to False + If True, compute p, r, f for each doc separately then compute the + mean of each score over docs. This is *not* the correct + implementation, but it corresponds to that in DPLP. + + add_trivial_spans : boolean, defaults to False + If True, trivial spans 0-0, 0-n, 1-n are added ; this is meant to + replicate the evaluation procedure of Li et al.'s dependency RST + parser. + + stringent : boolean, defaults to False + TODO + + out_format : str, one of {'str', 'latex'} + Output format. + """ + # filter root or leaves, depending on the type of ctree + if ctree_type not in ['RST', 'SimpleRST']: + raise ValueError("ctree_type should be one of {'RST', 'SimpleRST'}") + if ctree_type == 'RST': + # standard RST ctree: exclude root + exclude_root = True + subtree_filter = subtree_filter + elif ctree_type == 'SimpleRST': + # SimpleRST variant: keep root, exclude leaves + exclude_root = False # TODO try True first, should get same as before + not_leaf = lambda t: t.height() > 2 # TODO unit test! 
+ if subtree_filter is None: + subtree_filter = not_leaf + else: + subtree_filter = lambda t: not_leaf(t) and subtree_filter(t) + + # select metrics and the corresponding functions + metric_types = ['S', 'N', 'R', 'F'] + if metric_type not in metric_types: + raise ValueError("metric_type must be one of {}".format(metric_types)) + metric2lbl_fn = dict(LBL_FNS) + lbl_fn = (metric_type, metric2lbl_fn[metric_type]) + + return parseval_similarity(parser_preds, + exclude_root=exclude_root, + subtree_filter=subtree_filter, + lbl_fn=lbl_fn, + span_type=span_type, + digits=digits, + percent=percent, + print_support=print_support, + per_doc=per_doc, + add_trivial_spans=add_trivial_spans, + out_format=out_format) + + def rst_parseval_report(ctree_true, ctree_pred, ctree_type='RST', subtree_filter=None, metric_types=None, span_type='edus', From 7e6b709349a05f1b5f9d00d3e1f392f32db9ac93 Mon Sep 17 00:00:00 2001 From: moreymat Date: Sun, 21 May 2017 13:32:06 +0200 Subject: [PATCH 35/44] FIX educe.metrics.parseval missing newline --- educe/metrics/parseval.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/educe/metrics/parseval.py b/educe/metrics/parseval.py index 87c072c..3a3a65b 100644 --- a/educe/metrics/parseval.py +++ b/educe/metrics/parseval.py @@ -347,7 +347,8 @@ def parseval_similarity(parser_preds, '\\begin{table}[h]', '\\begin{center}', '\\begin{tabular}{' + 'l' * len(headers) +'}', - '\\toprule' + '\\toprule', + '' ]) report += fmt % tuple(headers) report += '\n' From 1752202de59e1b4a02b667917e8d88edb15f84b6 Mon Sep 17 00:00:00 2001 From: moreymat Date: Sun, 21 May 2017 13:52:34 +0200 Subject: [PATCH 36/44] FIX pairwise sim report: no underscore --- educe/metrics/parseval.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/educe/metrics/parseval.py b/educe/metrics/parseval.py index 3a3a65b..321f404 100644 --- a/educe/metrics/parseval.py +++ b/educe/metrics/parseval.py @@ -389,7 +389,7 @@ def parseval_similarity(parser_preds, '\\caption{\\label{ctree-sim} Similarity matrix on parsers predictions against non-binarized trees.}', '\\end{table}' ]) - report = report.replace('_', '\_') + report = report.replace('_', ' ') return report From f462d4746908dd9ff329df12cd87fd9ba09ff41b Mon Sep 17 00:00:00 2001 From: moreymat Date: Wed, 7 Jun 2017 13:40:47 +0200 Subject: [PATCH 37/44] MAINT backport compatible changes and fixes from master, eg. 
doc_glob --- educe/annotation.py | 159 +++++++++++++++++------- educe/corpus.py | 39 ++++-- educe/external/corenlp.py | 14 ++- educe/external/parser.py | 6 +- educe/external/postag.py | 6 +- educe/graph.py | 15 +-- educe/pdtb/corpus.py | 14 ++- educe/pdtb/parse.py | 42 ++++--- educe/pdtb/pdtbx.py | 30 ++--- educe/pdtb/tests.py | 32 +++-- educe/pdtb/util/cmd/extract.py | 4 +- educe/rst_dt/corpus.py | 30 +++-- educe/rst_dt/learning/base.py | 3 +- educe/rst_dt/learning/doc_vectorizer.py | 56 ++++++--- educe/rst_dt/sdrt.py | 7 +- educe/stac/context.py | 2 +- educe/stac/corpus.py | 28 ++++- educe/stac/edit/cmd/insert.py | 6 +- educe/stac/edit/cmd/rewrite.py | 5 +- educe/stac/fusion.py | 50 ++++---- educe/stac/learning/cmd/res_nps.py | 10 +- educe/stac/learning/doc_vectorizer.py | 33 ++--- educe/stac/lexicon/markers.py | 16 +-- educe/stac/lexicon/pdtb_markers.py | 8 +- educe/stac/lexicon/wordclass.py | 4 +- educe/stac/postag.py | 3 +- educe/stac/util/cmd/count_rfc.py | 4 +- educe/tests.py | 2 +- setup.py | 1 - 29 files changed, 396 insertions(+), 233 deletions(-) diff --git a/educe/annotation.py b/educe/annotation.py index 0262e2c..e56a1dc 100644 --- a/educe/annotation.py +++ b/educe/annotation.py @@ -202,31 +202,48 @@ def __repr__(self): class Standoff(object): """A standoff object ultimately points to some piece of text. - The pointing is not necessarily direct though + The pointing is not necessarily direct though. - Parameters + Attributes ---------- - origin : educe.corpus.FileId + origin : educe.corpus.FileId, optional FileId of the document supporting this standoff. """ def __init__(self, origin=None): self.origin = origin def _members(self): - """ - Any annotations contained within this annotation. + """Any annotations contained within this annotation. Must return None if is a terminal annotation (not the same meaning as returning the empty list). Non-terminal annotations must override this. + + Returns + ------- + res : list of Standoff or None + Annotations contained within this annotation ; None for + terminal annotations. """ return None def _terminals(self, seen=None): - """ + """Terminal annotations contained within this annotation. + For terminal annotations, this is just the annotation itself. For non-terminal annotations, this recursively fetches the - terminals + terminals. + + Parameters + ---------- + seen : optional + List of already annotations that have already been seen, so + as to avoid returning duplicates. + + Returns + ------- + res : list of Standoff + List of terminal annotations for this annotation. """ my_members = self._members() if my_members is None: @@ -241,7 +258,15 @@ def text_span(self): to the latest. Corner case: if this is an empty non-terminal (which would be a very - weird thing indeed), return None + weird thing indeed), return None. + + Returns + ------- + res : Span or None + Span from the first character of the earliest terminal + annotation contained here, to the last character of the + latest terminal annotation ; None if this annotation has no + terminal. """ terminals = list(self._terminals()) if len(terminals) > 0: @@ -253,34 +278,73 @@ def text_span(self): def encloses(self, other): """ - True if this annotations's span encloses the span of the other. + True if this annotation's span encloses the span of the other. `s1.encloses(s2)` is shorthand for `s1.text_span().encloses(s2.text_span())` + + Parameters + ---------- + other : Standoff + Other annotation. 
+ + Returns + ------- + res : boolean + True if this annotation's span encloses the span of the + other. """ return self.text_span().encloses(other.text_span()) def overlaps(self, other): """ - True if this annotations's span encloses the span of the other. + True if this annotations's span overlaps with the span of the other. `s1.overlaps(s2)` is shorthand for `s1.text_span().overlaps(s2.text_span())` + + Parameters + ---------- + other : Standoff + Other annotation. + + Returns + ------- + res : boolean + True if this annotation's span overlaps with the span of the + other. """ return self.text_span().overlaps(other.text_span()) # pylint: enable=no-self-use class Annotation(Standoff): - """ - Any sort of annotation. Annotations tend to have + """Any sort of annotation. + Annotations tend to have: * span: some sort of location (what they are annotating) * type: some key label (we call a type) * features: an attribute to value dictionary """ - def __init__(self, anno_id, span, atype, features, - metadata=None, origin=None): + def __init__(self, anno_id, span, atype, features, metadata=None, + origin=None): + """Init method. + + Parameters + ---------- + anno_id : TODO + Identifier for this annotation. + span : Span + Coordinates of the annotated span. + atype : str + Annotation type. + features : dict from str to str + Feature as a dict from feature_name to feature_value. + metadata : dict from str to str, optional + Metadata for the annotation, eg. author, creation date... + origin : FileId, optional + FileId of the document that supports this annotation. + """ Standoff.__init__(self, origin) self.origin = origin self._anno_id = anno_id @@ -298,14 +362,16 @@ def __str__(self): (self.identifier(), self.type, self.span, feats)) def local_id(self): - """ - An identifier which is sufficient to pick out this annotation within a - single annotation file + """Local identifier. + + An identifier which is sufficient to pick out this annotation + within a single annotation file. """ return self._anno_id def identifier(self): - """ + """Global identifier if possible, else local identifier. + String representation of an identifier that should be unique to this corpus at least. @@ -318,7 +384,7 @@ def identifier(self): * and the id from the XML file If we don't have an origin we fall back to just the id provided - by the XML file + by the XML file. See also `position` as potentially a safer alternative to this (and what we mean by safer) @@ -331,11 +397,14 @@ def identifier(self): class Unit(Annotation): + """Unit annotation. + + An annotation over a span of text. + """ - An annotation over a span of text - """ - def __init__(self, unit_id, span, utype, features, - metadata=None, origin=None): + + def __init__(self, unit_id, span, utype, features, metadata=None, + origin=None): Annotation.__init__(self, unit_id, span, utype, features, metadata, origin) @@ -356,13 +425,15 @@ def position(self): **position vs identifier** - This is a trade-off. One the hand, you can see the position as being - a safer way to identify a unit, because it obviates having to worry - about your naming mechanism guaranteeing stability across the board - (eg. two annotators stick an annotation in the same place; does it have - the same name). On the *other* hand, it's a bit harder to uniquely - identify objects that may coincidentally fall in the same span. So - how much do you trust your IDs? + This is a trade-off. 
+ On the one hand, you can see the position as being a safer way + to identify a unit, because it obviates having to worry about + your naming mechanism guaranteeing stability across the board + (eg. two annotators stick an annotation in the same place; does + it have the same name). + On the *other* hand, it's a bit harder to uniquely identify + objects that may coincidentally fall in the same span. + So how much do you trust your IDs? """ if self.origin is None: ostuff = [] @@ -384,20 +455,24 @@ class Relation(Annotation): `fleshout` is called (corpus slurping normally fleshes out documents and thus their relations). - Parameters - ---------- - rel_id : string - Relation id - span : RelSpan - Pair of units connected by this relation - rtype : string - Relation type - features : dict - Features - metadata : TODO - TODO """ + def __init__(self, rel_id, span, rtype, features, metadata=None): + """Init method. + + Parameters + ---------- + rel_id : string + Relation id + span : RelSpan + Pair of units connected by this relation + rtype : string + Relation type + features : dict + Features + metadata : dict from str to str, optional + Metadata for this annotation. + """ Annotation.__init__(self, rel_id, span, rtype, features, metadata) self.source = None # to be defined in fleshout 'source annotation; will be defined by fleshout' diff --git a/educe/corpus.py b/educe/corpus.py index 23d88f5..cbcf234 100644 --- a/educe/corpus.py +++ b/educe/corpus.py @@ -54,8 +54,8 @@ def __init__(self, doc, subdoc, stage, annotator): self.annotator = annotator def __str__(self): - return "%s [%s] %s %s" % (self.doc, self.subdoc, self.stage, - self.annotator) + return "%s [%s] %s %s" % ( + self.doc, self.subdoc, self.stage, self.annotator) def _tuple(self): """ @@ -129,7 +129,7 @@ class Reader: def __init__(self, root): self.rootdir = root - def files(self): + def files(self, doc_glob=None): """ Return a dictionary from FileId to (tuples of) filepaths. The tuples correspond to files that are considered to 'belong' @@ -137,24 +137,39 @@ def files(self): the text file and its annotations Derived classes + + Parameters + ---------- + doc_glob : str, optional + Glob expression for names of game folders ; if `None`, + subclasses are expected to use the wildcard '*' that matches + all strings. """ - def slurp(self, cfiles=None, verbose=False): + def slurp(self, cfiles=None, doc_glob=None, verbose=False): """ Read the entire corpus if `cfiles` is `None` or else the subset specified by `cfiles`. Return a dictionary from FileId to `educe.Annotation.Document` - :param cfiles: a dictionary like what `Corpus.files` would return - :type cfiles: dict + Parameters + ---------- + cfiles : dict, optional + Dict of files like what `Corpus.files()` would return. + + doc_glob : str, optional + Glob pattern for doc (folder) names ; ignored if `cfiles` + is not None. - :param verbose: print what we're reading to stderr - :type verbose: bool + verbose : boolean, defaults to False + If True, print what we're reading to stderr. 
""" - subcorpus = (cfiles if cfiles is not None - else self.files()) - return self.slurp_subcorpus(subcorpus, verbose) + if cfiles is None: + subcorpus = self.files(doc_glob=doc_glob) + else: + subcorpus = cfiles + return self.slurp_subcorpus(subcorpus, verbose=verbose) def slurp_subcorpus(self, cfiles, verbose=False): """ @@ -166,6 +181,6 @@ def filter(self, d, pred): """ Convenience function equivalent to :: - { k: v for k, v in d.items() if pred(k) } + { k:v for k,v in d.items() if pred(k) } """ return dict([(k, v) for k, v in d.items() if pred(k)]) diff --git a/educe/external/corenlp.py b/educe/external/corenlp.py index e27746f..3f8518f 100644 --- a/educe/external/corenlp.py +++ b/educe/external/corenlp.py @@ -46,11 +46,21 @@ class CoreNlpToken(postag.Token): Attributes ---------- - features : dict(string, string) + features : dict from str to str Additional info found by corenlp about the token (eg. `x.features['lemma']`) """ def __init__(self, t, offset, origin=None): + """ + Parameters + ---------- + t : dict + Token from corenlp's XML output. + offset : int + Offset from the span of the corenlp token to the document. + origin : FileId, optional + Identifier for the document. + """ extent = t['extent'] word = t['word'] tag = t['POS'] @@ -69,7 +79,7 @@ def __repr__(self): class CoreNlpWrapper(object): - """Wrapper for the CoreNLP parsing system""" + """Wrapper for the CoreNLP parsing system.""" def __init__(self, corenlp_dir): """Setup common attributes""" diff --git a/educe/external/parser.py b/educe/external/parser.py index 9ba3d1e..3bda351 100644 --- a/educe/external/parser.py +++ b/educe/external/parser.py @@ -134,15 +134,15 @@ def build(cls, tree, tokens): ---------- tree : nltk.Tree Original NLTK tree. - tokens : iterable - List of replacement leaves. + tokens : iterable of Token + Sequence of replacement leaves. Returns ------- ctree : ConstituencyTree ConstituencyTree where the internal nodes have the same labels as in the original NLTK tree and the leaves - correspond to the given list of tokens. + correspond to the given sequence of tokens. """ toks = deque(tokens) diff --git a/educe/external/postag.py b/educe/external/postag.py index 937f725..a25a671 100644 --- a/educe/external/postag.py +++ b/educe/external/postag.py @@ -180,19 +180,17 @@ def token_spans(text, tokens, offset=0): Parameters ---------- - text : string + text : str Base text. - tokens : sequence of RawToken Sequence of raw tokens in the text. - offset : int, defaults to 0 Offset for spans. Returns ------- res : list of Token - Sequence of proper educe `Token`s with their span. + Sequence of proper educe Tokens with their span. 
""" token_words = [tok.word for tok in tokens] spans = generic_token_spans(text, token_words, offset) diff --git a/educe/graph.py b/educe/graph.py index 3a38869..bacfaab 100644 --- a/educe/graph.py +++ b/educe/graph.py @@ -427,8 +427,8 @@ def _attrs(self, x): elif self.has_node(x): return self.node_attributes_dict(x) else: - raise Exception('Tried to get attributes of non-existing' - ' object ' + str(x)) + raise Exception("Tried to get attributes of non-existing object" + " " + str(x)) def relations(self): """ @@ -591,8 +591,8 @@ def _repr_svg_(self): raise Exception('Cannot find the dot binary from Graphviz package') out, err = process.communicate(dot_string) if err: - raise Exception('Cannot create svg representation by running' - ' dot from string\n:%s' % dot_string) + raise Exception("Cannot create svg representation by running " + "dot from string\n:%s" % dot_string) return out @@ -834,9 +834,10 @@ def _add_complex_cdu(self, hyperedge): def __init__(self, anno_graph): """ - Args - - anno_graph (Graph): abstract annotation graph + Parameters + ---------- + anno_graph : Graph + Abstract annotation graph. """ self.core = anno_graph self.doc = self.core.doc diff --git a/educe/pdtb/corpus.py b/educe/pdtb/corpus.py index 44605da..ece7f03 100644 --- a/educe/pdtb/corpus.py +++ b/educe/pdtb/corpus.py @@ -24,9 +24,19 @@ class Reader(educe.corpus.Reader): def __init__(self, corpusdir): educe.corpus.Reader.__init__(self, corpusdir) - def files(self): + def files(self, doc_glob=None): + """ + Parameters + ---------- + doc_glob : str, optional + Glob expression for document (folder) names ; if `None`, it + uses the wildcard '*/*' for folder names and file basenames. + """ + if doc_glob is None: + doc_glob = '*/*' anno_files = {} - full_glob = os.path.join(self.rootdir, '*/*.pdtb') + full_glob = os.path.join( + self.rootdir, '{doc_glob}.pdtb'.format(doc_glob=doc_glob)) for fname in glob(full_glob): bname = os.path.basename(fname) doc = os.path.splitext(bname)[0] diff --git a/educe/pdtb/parse.py b/educe/pdtb/parse.py index d5a8f9e..5511f71 100755 --- a/educe/pdtb/parse.py +++ b/educe/pdtb/parse.py @@ -177,16 +177,18 @@ def __init__(self, selection, attribution=None, sup=None): def _substr(self): sup_str = ' + %s' % self.sup if self.sup else '' - return '%s | %s%s' % (Selection._substr(self), self.attribution, - sup_str) + return '%s | %s%s' % ( + Selection._substr(self), self.attribution, sup_str) class Relation(PdtbItem): """ - Fields: - - * self.arg1 - * self.arg2 + Attributes + ---------- + arg1 : TODO + TODO + arg2 : TODO + TODO """ def __init__(self, args): if len(args) == 4: @@ -293,10 +295,11 @@ def __init__(self, infsite, args): def _substr(self): return Relation._substr(self) + # --------------------------------------------------------------------- # not-quite-lexing # --------------------------------------------------------------------- - +# # FIXME # funcparserlib works on a stream of arbitrary tokens, eg. the output of # a lexer. At the time of this writing, I didn't trust any of the fancy @@ -306,8 +309,7 @@ def _substr(self): # the raw text bits eg. `r'#### Text ####\n(.*?)\n##############'`; and # provide some abstractions over tokens, we could maybe simplify the # parser a lot... which could in turn make it faster? 
- - +# class _Char(object): def __init__(self, value, abspos, line, relpos): self.value = value @@ -490,6 +492,7 @@ def _helper(tokens, s): _helper.name = u'{ literal %s }' % xs return _helper + if _DEBUG: _annotate = _annotate_debug _mkstr = _mkstr_debug @@ -503,7 +506,8 @@ def _helper(tokens, s): # --------------------------------------------------------------------- # elementary parts # --------------------------------------------------------------------- -_nat = fp.oneplus(_satisfies(lambda c: c.isdigit())) >> (lambda x: int(_mkstr(x))) +_nat = fp.oneplus(_satisfies(lambda c: c.isdigit())) >> ( + lambda x: int(_mkstr(x))) _nl = fp.skip(_oneof("\r\n")) _comma = fp.skip(_oneof(",")) _semicolon = fp.skip(_oneof(";")) @@ -618,8 +622,8 @@ def _subsection_begin(t): _SemanticClassN = _sepby(_fullstop, _SemanticClassWord) >> SemClass _SemanticClass1 = _SemanticClassN _SemanticClass2 = _SemanticClassN -_semanticClass = _SemanticClass1 + fp.maybe( - _sp + _comma + _sp + _SemanticClass2) +_semanticClass = _SemanticClass1 + fp.maybe(_sp + _comma + _sp + + _SemanticClass2) # always followed by a comma (yeah, a bit clunky) _ConnHead = _skipto_mkstr(_comma) @@ -631,9 +635,12 @@ def _mkConnective(c, semclasses): return Connective(c, *semclasses) -_connHeadSemanticClass = _ConnHead + _sp + _semanticClass >> _unarg(_mkConnective) -_conn1SemanticClass = _Conn1 + _sp + _semanticClass >> _unarg(_mkConnective) -_conn2SemanticClass = _Conn2 + _sp + _semanticClass >> _unarg(_mkConnective) +_connHeadSemanticClass = _ConnHead + _sp + _semanticClass >> _unarg( + _mkConnective) +_conn1SemanticClass = _Conn1 + _sp + _semanticClass >> _unarg( + _mkConnective) +_conn2SemanticClass = _Conn2 + _sp + _semanticClass >> _unarg( + _mkConnective) # --------------------------------------------------------------------- @@ -674,6 +681,7 @@ def _mk_args_and_sups(): sans_sup1 = _lines(rest) >> (lambda xs: tuple([None] + list(xs))) return with_sup1 | sans_sup1 # yuck :-( + _args_and_sups = _mk_args_and_sups() _args_only =\ _lines([_arg_no_features('arg1'), @@ -796,8 +804,8 @@ def parse(path): Parameters ---------- - path : string - Path to the .pdtb file + path : str + Path to the .pdtb file (?) 
Returns ------- diff --git a/educe/pdtb/pdtbx.py b/educe/pdtb/pdtbx.py index f4a8faa..84e64d7 100644 --- a/educe/pdtb/pdtbx.py +++ b/educe/pdtb/pdtbx.py @@ -46,8 +46,8 @@ def _read_Selection(node): attr = node.attrib return ty.Selection(span=_read_SpanList(attr['spanList']), gorn=_read_GornAddressList(attr['gornList']), - text=on_single_element(node, None, lambda x: x.text, - 'text')) + text=on_single_element( + node, None, lambda x: x.text, 'text')) def _read_InferenceSite(node): @@ -83,7 +83,8 @@ def _read_Sup(node): def _read_Arg(node): sup = on_single_element(node, (), _read_Sup, 'sup') - attribution = on_single_element(node, (), _read_Attribution, 'attribution') + attribution = on_single_element( + node, (), _read_Attribution, 'attribution') return ty.Arg(selection=_read_Selection(node), attribution=(None if attribution is () else attribution), sup=(None if sup is () else sup)) @@ -98,9 +99,10 @@ def _read_Args(node): def _read_ExplicitRelationFeatures(node): - attribution = on_single_element(node, None, _read_Attribution, - 'attribution') - connhead = on_single_element(node, None, _read_Connective, 'connhead') + attribution = on_single_element( + node, None, _read_Attribution, 'attribution') + connhead = on_single_element( + node, None, _read_Connective, 'connhead') return ty.ExplicitRelationFeatures(attribution=attribution, connhead=connhead) @@ -120,8 +122,8 @@ def _read_ImplicitRelationFeatures(node): raise EduceXmlException('Was expecting no more than two connectives ' '(got %d)' % len(connectives)) - attribution = on_single_element(node, None, _read_Attribution, - 'attribution') + attribution = on_single_element( + node, None, _read_Attribution, 'attribution') connective1 = _read_Connective(connectives[0]) connective2 = (_read_Connective(connectives[1]) if len(connectives) == 2 else None) @@ -137,8 +139,8 @@ def _read_ImplicitRelation(node): def _read_AltLexRelationFeatures(node): - attribution = on_single_element(node, None, _read_Attribution, - 'attribution') + attribution = on_single_element( + node, None, _read_Attribution, 'attribution') attr = node.attrib semclass1_ = attr['semclass1'] semclass2_ = attr.get('semclass2', None) # optional @@ -178,8 +180,8 @@ def read_Relation(node): elif tag == 'noRelation': return _read_NoRelation(node) else: - raise EduceXmlException("Don't know how to read relation with name " - "%s" % tag) + raise EduceXmlException("Don't know how to read relation with " + "name %s" % tag) def read_Relations(node): @@ -326,8 +328,8 @@ def Relation_xml(itm): elif isinstance(itm, ty.NoRelation): return _NoRelation_xml(itm) else: - raise Exception("Don't know how to translate relation of type " - "%s" % type(itm)) + raise Exception("Don't know how to translate relation of " + "type %s" % type(itm)) def Relations_xml(itms): diff --git a/educe/pdtb/tests.py b/educe/pdtb/tests.py index 7b5fd94..06ba162 100644 --- a/educe/pdtb/tests.py +++ b/educe/pdtb/tests.py @@ -88,7 +88,8 @@ def test_lines(self): parser = p._lines([char("a"), char("b"), char("c")]) >> p._mkstr self.assertParse(parser, expected, txt) - parser = p._lines([char("a"), char("b"), p._OptionalBlock(char("c"))]) >> p._mkstr + parser = p._lines( + [char("a"), char("b"), p._OptionalBlock(char("c"))]) >> p._mkstr self.assertParse(parser, expected, txt) def test_tok(self): @@ -177,8 +178,8 @@ def test_semclass(self): self.assertParse(p._semanticClass, expected, txt) def test_connective(self): - expected = p.Connective('also', p.SemClass(['Expansion', - 'Conjunction'])) + expected = p.Connective( + 
'also', p.SemClass(['Expansion', 'Conjunction'])) txt = 'also, Expansion.Conjunction' self.assertParse(p._conn1SemanticClass, expected, txt) self.assertParse(p._conn2SemanticClass, expected, txt) @@ -193,25 +194,22 @@ def test_sup(self): def test_implicit_features_1(self): expected_attr = p.Attribution('Wr', 'Comm', 'Null', 'Null') - expected_conn = p.Connective('also', p.SemClass(['Expansion', - 'Conjunction'])) - expected = p.ImplicitRelationFeatures(expected_attr, expected_conn, - None) + expected_conn = p.Connective( + 'also', p.SemClass(['Expansion', 'Conjunction'])) + expected = p.ImplicitRelationFeatures( + expected_attr, expected_conn, None) txt = ex_implicit_attribution self.assertParse(p._implicitRelationFeatures, expected, txt) def test_implicit_features_2(self): - expected_conn1 = p.Connective('in particular', - p.SemClass(['Expansion', - 'Restatement', - 'Specification'])) - expected_conn2 = p.Connective('because', - p.SemClass(['Contingency', - 'Cause', - 'Reason'])) + expected_conn1 = p.Connective( + 'in particular', p.SemClass(['Expansion', 'Restatement', + 'Specification'])) + expected_conn2 = p.Connective( + 'because', p.SemClass(['Contingency', 'Cause', 'Reason'])) expected_attr = p.Attribution('Wr', 'Comm', 'Null', 'Null') - expected = p.ImplicitRelationFeatures(expected_attr, expected_conn1, - expected_conn2) + expected = p.ImplicitRelationFeatures( + expected_attr, expected_conn1, expected_conn2) txt = ex_implicit_features self.assertParse(p._implicitRelationFeatures, expected, txt) diff --git a/educe/pdtb/util/cmd/extract.py b/educe/pdtb/util/cmd/extract.py index 8070a83..3a0789c 100644 --- a/educe/pdtb/util/cmd/extract.py +++ b/educe/pdtb/util/cmd/extract.py @@ -5,14 +5,14 @@ Extract features 2017-01-27 this code is broken ; it relies on stac.keys.KeyGroupWriter -which was deprecated and removed a while back. +which was deprecated and removed a while back (MM to self: way to go!). """ import codecs import csv import os -import stac.csv +import stac.util.stac_csv_format import stac.keys from ..args import\ diff --git a/educe/rst_dt/corpus.py b/educe/rst_dt/corpus.py index faf0cfd..61f7f1f 100644 --- a/educe/rst_dt/corpus.py +++ b/educe/rst_dt/corpus.py @@ -5,11 +5,9 @@ Corpus management (re-exported by educe.rst_dt) """ +from glob import glob import os import sys -from glob import glob -from os.path import dirname -from os.path import join from nltk import Tree @@ -38,20 +36,25 @@ class Reader(educe.corpus.Reader): def __init__(self, corpusdir): educe.corpus.Reader.__init__(self, corpusdir) - def files(self, exclude_file_docs=False): + def files(self, doc_glob=None): """ Parameters ---------- - exclude_file_docs : boolean, optional (default=False) - If True, fileX documents are ignored. The figures reported by - (Li et al., 2014) on the RST-DT corpus indicate they exclude - fileN files, whereas Joty seems to include them. - fileN documents are more damaged than wsj_XX documents, e.g. - text mismatches with the corresponding document in the PTB. + doc_glob : str, optional + Glob for document names, ie. file basenames. A common + pattern is `doc_glob='wsj_*'` to exclude documents whose + file basenames are of the form `fileX`. + `fileX` documents are damaged compared to `wsj_XX` documents + ie. their text and that of the corresponding document in the + PTB mismatch, and text formatting is scrambled. For example, + the figures reported in the paper of (Li et al., 2014) + indicate they only consider `wsj_XX` files. 
""" + if doc_glob is None: + doc_glob = '*' anno_files = {} - dis_glob = 'wsj_*.dis' if exclude_file_docs else '*.dis' - full_glob = os.path.join(self.rootdir, dis_glob) + full_glob = os.path.join( + self.rootdir, '{doc_glob}.dis'.format(doc_glob=doc_glob)) for fname in glob(full_glob): text_file = os.path.splitext(fname)[0] @@ -148,7 +151,8 @@ def __init__(self, corpus_dir, args, coarse_rels=False, exclude_file_docs=False): self.reader = Reader(corpus_dir) # pre-load corpus - anno_files_unfltd = self.reader.files(exclude_file_docs) + doc_glob = 'wsj_*' if exclude_file_docs else None + anno_files_unfltd = self.reader.files(doc_glob=doc_glob) is_interesting = educe.util.mk_is_interesting(args) anno_files = self.reader.filter(anno_files_unfltd, is_interesting) self.corpus = self.reader.slurp(anno_files, verbose=True) diff --git a/educe/rst_dt/learning/base.py b/educe/rst_dt/learning/base.py index 07f9733..19b57ca 100644 --- a/educe/rst_dt/learning/base.py +++ b/educe/rst_dt/learning/base.py @@ -265,8 +265,7 @@ def preprocess(self, doc, strict=False): (toks_beg < para_end, toks_end > para_beg, toks_beg >= tree_beg, - toks_end <= tree_end) - ) + toks_end <= tree_end)) )[0] overtoks = [tokens[i] for i in overtoks_idc] syn_node_seq = syntactic_node_seq( diff --git a/educe/rst_dt/learning/doc_vectorizer.py b/educe/rst_dt/learning/doc_vectorizer.py index 073daf3..dd9659f 100644 --- a/educe/rst_dt/learning/doc_vectorizer.py +++ b/educe/rst_dt/learning/doc_vectorizer.py @@ -13,21 +13,6 @@ class DocumentLabelExtractor(object): """Label extractor for the RST-DT treebank. - Parameters - ---------- - instance_generator : generator - Generator that enumerates the instances from a doc. - - ordered_pairs : boolean (default: True) - True if the generated instances are ordered pairs of DUs: - (du1, du2) != (du2, du1). - - unknown_label : str - Reserved label for unknown cases. - - labelset : TODO - TODO - Attributes ---------- fixed_labelset_ : boolean @@ -42,6 +27,19 @@ def __init__(self, instance_generator, ordered_pairs=True, unknown_label='__UNK__', labelset=None): + """ + Parameters + ---------- + instance_generator : generator + Generator that enumerates the instances from a doc. + ordered_pairs : boolean (default: True) + True if the generated instances are ordered pairs of DUs: + (du1, du2) != (du2, du1). + unknown_label : str, defaults to __UNK__ + Reserved label for unknown cases. + labelset : TODO + TODO + """ self.instance_generator = instance_generator self.ordered_pairs = ordered_pairs # 2016-09-30 self.unknown_label = unknown_label @@ -115,12 +113,12 @@ def decode(self, doc): Parameters ---------- - doc: DocumentPlus + doc : DocumentPlus Rich representation of the document. Returns ------- - doc: DocumentPlus + doc : DocumentPlus Rich representation of `doc`. """ if not isinstance(doc, DocumentPlus): @@ -252,12 +250,12 @@ def _extract_feature_vectors(self, doc): Parameters ---------- - doc: educe.rst_dt.document_plus.DocumentPlus + doc : educe.rst_dt.document_plus.DocumentPlus Rich representation of the document. Returns ------- - feat_vecs: list of feature vectors + feat_vecs : list of feature vectors List of feature vectors, one for every pair of EDUs (in the order in which they are generated by `self.instance_generator()`). @@ -458,6 +456,26 @@ def _limit_features(self, vocab_df, vocabulary, high=None, low=None, This is essentially a reimplementation of the one in sklearn.feature_extraction.text.CountVectorizer, except vocab_df is computed differently. 
+ + Parameters + ---------- + vocab_df : TODO + TODO + vocabulary : TODO + TODO + high : TODO + TODO + low : TODO + TODO + limit : TODO + TODO + + Returns + ------- + vocabulary : dict from string to int + Vocabulary. + removed_feats : set + Set of removed features. """ if high is None and low is None and limit is None: return set() diff --git a/educe/rst_dt/sdrt.py b/educe/rst_dt/sdrt.py index 2d91b42..a3c9e4c 100644 --- a/educe/rst_dt/sdrt.py +++ b/educe/rst_dt/sdrt.py @@ -23,7 +23,8 @@ class CDU: """Complex Discourse Unit. A CDU contains one or more discourse units, and tracks relation - instances between its members. Both CDU and EDU are discourse units. + instances between its members. + Both CDU and EDU are discourse units. Attributes ---------- @@ -194,8 +195,8 @@ def rst_to_sdrt(tree): nuclei = [x for x in tree if x.label().is_nucleus()] satellites = [x for x in tree if x.label().is_satellite()] if len(nuclei) + len(satellites) != len(tree): - raise ValueError( - "Nodes that are neither Nuclei nor Satellites\n%s" % tree) + raise ValueError("Nodes that are neither Nuclei nor " + "Satellites\n%s" % tree) if len(nuclei) == 0: raise ValueError("No nucleus:\n%s" % tree) diff --git a/educe/stac/context.py b/educe/stac/context.py index d2f3b90..c3cb97a 100644 --- a/educe/stac/context.py +++ b/educe/stac/context.py @@ -220,7 +220,7 @@ def for_edus(cls, doc, postags=None): Returns ------- contexts: dict(educe.glozz.Unit, Context) - A dictionary with a context For each EDU in the document + A dictionary with a context for each EDU in the document. """ if postags: egraph = EnclosureGraph(doc, postags) diff --git a/educe/stac/corpus.py b/educe/stac/corpus.py index 28abd03..6ae4ded 100644 --- a/educe/stac/corpus.py +++ b/educe/stac/corpus.py @@ -27,9 +27,19 @@ class Reader(educe.corpus.Reader): def __init__(self, corpusdir): educe.corpus.Reader.__init__(self, corpusdir) - def files(self): + def files(self, doc_glob=None): + """Gather files for docs whose folder name matches `doc_glob`. + + Parameters + ---------- + doc_glob : str, optional + Glob expression for document (folder) names ; if `None`, + it uses the wildcard '*' to match all strings. + """ + if doc_glob is None: + doc_glob = '*' corpus = OrderedDict() - full_glob = os.path.join(self.rootdir, '*') + full_glob = os.path.join(self.rootdir, doc_glob) anno_glob = '*.aa' def register(stage, annotator, anno_file): @@ -100,9 +110,19 @@ class LiveInputReader(Reader): def __init__(self, corpusdir): Reader.__init__(self, corpusdir) - def files(self): + def files(self, doc_glob=None): + """ + Parameters + ---------- + doc_glob : str, optional + Glob expression for document (folder) names ; if `None`, it + uses the wildcard '*' for file basenames. 
+ """ + if doc_glob is None: + doc_glob = '*' corpus = {} - for anno_file in glob(os.path.join(self.rootdir, '*.aa')): + for anno_file in glob(os.path.join( + self.rootdir, '{doc_glob}.aa'.format(doc_glob=doc_glob))): prefix = os.path.splitext(anno_file)[0] pair = (anno_file, prefix + '.ac') k = educe.corpus.FileId(doc=os.path.basename(prefix), diff --git a/educe/stac/edit/cmd/insert.py b/educe/stac/edit/cmd/insert.py index eabae7b..2bcc5ee 100644 --- a/educe/stac/edit/cmd/insert.py +++ b/educe/stac/edit/cmd/insert.py @@ -11,8 +11,10 @@ import educe.stac from educe.stac.util.annotate import show_diff -from educe.stac.util.args import (add_usual_input_args, add_usual_output_args, - announce_output_dir, get_output_dir) +from educe.stac.util.args import ( + add_usual_input_args, add_usual_output_args, announce_output_dir, + get_output_dir +) from educe.stac.util.output import save_document from educe.stac.util.doc import compute_renames, move_portion from .move import is_requested diff --git a/educe/stac/edit/cmd/rewrite.py b/educe/stac/edit/cmd/rewrite.py index be00d25..1540f0f 100644 --- a/educe/stac/edit/cmd/rewrite.py +++ b/educe/stac/edit/cmd/rewrite.py @@ -8,8 +8,9 @@ import copy -from educe.stac.util.args import (add_usual_input_args, read_corpus, - get_output_dir, announce_output_dir) +from educe.stac.util.args import ( + add_usual_input_args, read_corpus, get_output_dir, announce_output_dir +) from educe.stac.util.output import save_document from educe.stac.context import sorted_first_widest diff --git a/educe/stac/fusion.py b/educe/stac/fusion.py index 55f54f2..e883950 100644 --- a/educe/stac/fusion.py +++ b/educe/stac/fusion.py @@ -41,24 +41,23 @@ class Dialogue(object): - """STAC Dialogue + """STAC Dialogue. Note that input EDUs should be sorted by span. - Parameters - ---------- - anno : educe.stac.annotation.Unit - Glozz annotation corresponding to the dialogue ; only its - identifier is stored, currently. - - edus : list(educe.stac.annotation.Unit) - List of EDU annotations, sorted by their span. - - relations : list(educe.stac.annotation.Relation - List of relations between EDUs from the dialogue. - """ def __init__(self, anno, edus, relations): + """ + Parameters + ---------- + anno : educe.stac.annotation.Unit + Glozz annotation corresponding to the dialogue ; only its + identifier is stored, currently. + edus : list of educe.stac.annotation.Unit + List of EDU annotations, sorted by their span. + relations : list of educe.stac.annotation.Relation + List of relations between EDUs from the dialogue. + """ self.grouping = anno.identifier() self.edus = [FakeRootEDU] + edus self.relations = relations @@ -74,7 +73,7 @@ def edu_pairs(self): Yields ------ - (source, target) : tuple(educe.stac.annotation.Unit) + (source, target) : tuple of educe.stac.annotation.Unit Next candidate edge, as a pair of EDUs (source, target). 
""" i_edus = list(enumerate(self.edus)) @@ -83,7 +82,6 @@ def edu_pairs(self): for _, edu in i_edus: yield (fakeroot, edu) # generate all pairs of (real) EDUs - # real_pairs = [] # DEBUG for num1, edu1 in i_edus: def is_before(numedu2): 'true if we have seen the EDU already' @@ -94,16 +92,6 @@ def is_before(numedu2): for _, edu2 in itertools.dropwhile(is_before, i_edus): yield (edu1, edu2) yield (edu2, edu1) - # DEBUG - # real_pairs.append((edu1, edu2)) - # real_pairs.append((edu2, edu1)) - # end DEBUG - # DEBUG compare list of EDU pairs from the above loop with a - # one-liner - # real_pairs_itr = sorted(itertools.permutations(self.edus[1:])) - # assert real_pairs_itr != sorted(real_pairs) - # raise ValueError("woooop") - # end DEBUG # pylint: disable=too-many-instance-attributes @@ -119,6 +107,16 @@ class EDU(Unit): annotations and contexts """ def __init__(self, doc, discourse_anno, unit_anno): + """ + Parameters + ---------- + doc : ? + ? + discourse_anno : ? + Annotation from the discourse layer. + unit_anno : ? + Annotation from the units layer. + """ self._doc = doc self._anno = discourse_anno self._unit_anno = unit_anno @@ -274,10 +272,8 @@ def fuse_edus(discourse_doc, unit_doc, postags): ---------- discourse_doc : GlozzDocument Document from the "discourse" stage. - unit_doc : GlozzDocument Document from the "units" stage. - postags : list of Token Sequence of educe tokens predicted by the POS tagger for this document. diff --git a/educe/stac/learning/cmd/res_nps.py b/educe/stac/learning/cmd/res_nps.py index 848c747..5d4e11d 100644 --- a/educe/stac/learning/cmd/res_nps.py +++ b/educe/stac/learning/cmd/res_nps.py @@ -15,11 +15,13 @@ from educe.stac import postag, corenlp from educe.stac.annotation import is_edu -from educe.stac.learning.doc_vectorizer import (mk_env, get_players, - FeatureInput, LexWrapper) +from educe.stac.learning.doc_vectorizer import ( + mk_env, get_players, FeatureInput, LexWrapper +) from educe.stac.learning.features import enclosed_trees, is_nplike -from educe.util import (add_corpus_filters, fields_without, mk_is_interesting, - concat, concat_l) +from educe.util import ( + add_corpus_filters, fields_without, mk_is_interesting, concat, concat_l +) import educe.corpus import educe.glozz import educe.learning.keys diff --git a/educe/stac/learning/doc_vectorizer.py b/educe/stac/learning/doc_vectorizer.py index e7a6243..224bfe8 100644 --- a/educe/stac/learning/doc_vectorizer.py +++ b/educe/stac/learning/doc_vectorizer.py @@ -72,21 +72,6 @@ def transform(self, raw_documents): class LabelVectorizer(object): """Label extractor for the STAC corpus. - Parameters - ---------- - instance_generator : fun(doc) -> :obj:`list` of (EDU, EDU) - Function to enumerate the candidate instances from a doc. - - labels : :obj:`list` of str - Set of domain labels. If it is provided as a sequence, the - order of labels is preserved ; otherwise its elements are - sorted before storage. This guarantees a stable behaviour - across runs and platforms, which greatly facilitates - the comparability of models and results. - - zero : boolean, defaults to False - If True, transform() will return the unknown label (UNK) for - all (MM or only unrelated?) pairs. Attributes ---------- @@ -96,6 +81,24 @@ class LabelVectorizer(object): """ def __init__(self, instance_generator, labels, zero=False): + """ + + Parameters + ---------- + instance_generator : fun(doc) -> :obj:`list` of (EDU, EDU) + Function to enumerate the candidate instances from a doc. 
+ + labels : :obj:`list` of str + Set of domain labels. If it is provided as a sequence, the + order of labels is preserved ; otherwise its elements are + sorted before storage. This guarantees a stable behaviour + across runs and platforms, which greatly facilitates + the comparability of models and results. + + zero : boolean, defaults to False + If True, transform() will return the unknown label (UNK) for + all (MM or only unrelated?) pairs. + """ self.instance_generator = instance_generator if not isinstance(labels, Sequence): labels = sorted(labels) diff --git a/educe/stac/lexicon/markers.py b/educe/stac/lexicon/markers.py index c50ab3f..b93f115 100644 --- a/educe/stac/lexicon/markers.py +++ b/educe/stac/lexicon/markers.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- """ -Api on discourse markers (lexicon I/O mostly) +API on discourse markers (lexicon I/O mostly) """ try: import xml.etree.cElementTree as ET @@ -28,13 +28,13 @@ class Marker: version 2 has grammatical host and lemma """ def __init__(self, elmt, version="2", stop=_stopwords): - self._forms = [x.text.strip() - for x in elmt.findall(".//%s" % _table[version]["form"])] + self._forms = [x.text.strip() for x in elmt.findall( + ".//%s" % _table[version]["form"])] self.__dict__.update(elmt.attrib) # if version == "2": - self.relations = [x.attrib["relation"] - for x in elmt.findall(".//use")] + self.relations = [x.attrib["relation"] for x in elmt.findall( + ".//use")] else: self.relations = [x.strip() for x in self.relations.split(",")] self.lemma = self.forms[0] @@ -56,8 +56,8 @@ def __init__(self, infile, version="2", stop=_stopwords): """read lexconn file, version is 1 or 2 """ lex = ET.parse(infile) - markers = [Marker(x, version=version) - for x in lex.findall(".//%s" % _table[version]["marker"])] + markers = [Marker(x, version=version) for x in lex.findall( + ".//%s" % _table[version]["marker"])] markers = [x for x in markers if x.get_lemma() not in stop] self._markers = dict([(x.id, x) for x in markers]) @@ -76,5 +76,7 @@ def get_by_lemma(self, lemma): # tests if __name__ == "__main__": + import sys + infile = sys.argv[1] lex = LexConn(infile, version=sys.argv[2]) diff --git a/educe/stac/lexicon/pdtb_markers.py b/educe/stac/lexicon/pdtb_markers.py index 4faaa2f..252a2cb 100755 --- a/educe/stac/lexicon/pdtb_markers.py +++ b/educe/stac/lexicon/pdtb_markers.py @@ -96,12 +96,12 @@ def load_pdtb_markers_lexicon(filename): Parameters ---------- - filename: string - Path to the lexicon + filename : str + Path to the lexicon. Returns ------- - markers: dict(Marker, list(string)) + markers : dict(Marker, list(string)) Discourse markers and the relations they signal """ blacklist = frozenset(['\\wedge']) @@ -141,7 +141,7 @@ def read_lexicon(filename): Parameters ---------- - filename : string + filename : str Path to the lexicon. Returns diff --git a/educe/stac/lexicon/wordclass.py b/educe/stac/lexicon/wordclass.py index 51c91d1..be8d8f5 100644 --- a/educe/stac/lexicon/wordclass.py +++ b/educe/stac/lexicon/wordclass.py @@ -40,8 +40,8 @@ class LexEntry(namedtuple("LexEntry", def __new__(cls, word, lex_class, pos, subclass): pos = pos if pos != '??' 
else None subclass = subclass or None - return super(LexEntry, cls).__new__(cls, word, lex_class, pos, - subclass) + return super(LexEntry, cls).__new__( + cls, word, lex_class, pos, subclass) @classmethod def read_entry(cls, line): diff --git a/educe/stac/postag.py b/educe/stac/postag.py index 33155ee..d4a1390 100644 --- a/educe/stac/postag.py +++ b/educe/stac/postag.py @@ -116,8 +116,7 @@ def read_tags(corpus, root_dir): ---------- corpus : dict(FileId, GlozzDocument) Dictionary of documents keyed by their FileId. - - root_dir : string + root_dir : str Path to the directory containing the output of the POS tagger, one file per document. diff --git a/educe/stac/util/cmd/count_rfc.py b/educe/stac/util/cmd/count_rfc.py index 1b0959f..078f8c0 100644 --- a/educe/stac/util/cmd/count_rfc.py +++ b/educe/stac/util/cmd/count_rfc.py @@ -108,7 +108,7 @@ def display_violations(res): tres.append([row_name] + list(res[(col_name, table_name, row_name)] for col_name in col_names)) - print(tabulate(tres, headers=[table_name]+col_names)+'\n') + print(tabulate(tres, headers=[table_name]+col_names) + '\n') def display_power(res): @@ -129,7 +129,7 @@ def display_power(res): rfc_power = (100 * avg_frontier_size) / nb_edus row.append(rfc_power) tres.append(row) - print(tabulate(tres, headers=col_names, floatfmt='.1f') + '\n') + print(tabulate(tres, headers=col_names, floatfmt='.1f')+'\n') def config_argparser(parser): diff --git a/educe/tests.py b/educe/tests.py index 97b20a4..26a5d89 100644 --- a/educe/tests.py +++ b/educe/tests.py @@ -342,7 +342,7 @@ def test_copy(self): # including CDU should also result in members being included xset4 = set(['X2']) gr4 = gr.copy(nodeset=xset4) - self.assertEqual(xset2, gr4.edus()) + self.assertEqual(xset2, gr4.edus()) self.assertEqual(set(['X1', 'X2']), gr4.cdus()) diff --git a/setup.py b/setup.py index 80270a8..00931f6 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,6 @@ 'python-graph-core', 'python-graph-dot', 'frozendict', - 'sh', 'six', 'tabulate', 'nltk >= 3.0.0', From 48f7c26711e8622a58857cf6138e95a314c6f884 Mon Sep 17 00:00:00 2001 From: moreymat Date: Wed, 7 Jun 2017 17:20:19 +0200 Subject: [PATCH 38/44] MAINT minor cleanup --- educe/annotation.py | 2 +- educe/learning/keygroup_vectorizer.py | 10 ---------- 2 files changed, 1 insertion(+), 11 deletions(-) diff --git a/educe/annotation.py b/educe/annotation.py index e56a1dc..6f0ef39 100644 --- a/educe/annotation.py +++ b/educe/annotation.py @@ -249,7 +249,7 @@ def _terminals(self, seen=None): if my_members is None: return [self] seen = seen or [] - return chain.from_iterable([m._terminals(seen + my_members) + return chain.from_iterable([m._terminals(seen=seen + my_members) for m in my_members if m not in seen]) def text_span(self): diff --git a/educe/learning/keygroup_vectorizer.py b/educe/learning/keygroup_vectorizer.py index c008472..9103152 100644 --- a/educe/learning/keygroup_vectorizer.py +++ b/educe/learning/keygroup_vectorizer.py @@ -5,9 +5,6 @@ # lots of scikit-conventional names here from collections import defaultdict -import sys - -import numpy as np class KeyGroupVectorizer(object): @@ -88,13 +85,6 @@ def _count_vocab(self, vectors, fixed_vocab=False): x = feature_acc[current_row:next_row] X[-1].append(x) - if False: # DEBUG - n_edus = [len(y) for y in X] - print(len(vocabulary), sys.getsizeof(vocabulary)) - print(len(X), sum(len(y) for y in X), sys.getsizeof(X)) - print(sum(nb_edus * (nb_edus - 1) for nb_edus in n_edus)) - raise ValueError('woopti') - return vocabulary, X def fit_transform(self, 
vectors): From 6dfa47f2bd4e29ba06649a05eef4f454d2c88404 Mon Sep 17 00:00:00 2001 From: moreymat Date: Fri, 9 Jun 2017 16:52:16 +0200 Subject: [PATCH 39/44] ENH backwards compat: option to load/dump labels from/to features file --- educe/learning/edu_input_format.py | 94 +++++++++++++++++++++++++----- 1 file changed, 79 insertions(+), 15 deletions(-) diff --git a/educe/learning/edu_input_format.py b/educe/learning/edu_input_format.py index 22babfe..5700e18 100644 --- a/educe/learning/edu_input_format.py +++ b/educe/learning/edu_input_format.py @@ -9,12 +9,11 @@ import six -from .svmlight_format import dump_svmlight_file -# WIP load_edu_input_file -# FIXME adapt to STAC -from educe.annotation import Span -from educe.corpus import FileId -from educe.rst_dt.annotation import EDU as RstEDU +# FIXME adapt load_edu_input_file to STAC +from educe.annotation import Span # WIP load_edu_input_file +from educe.corpus import FileId # WIP load_edu_input_file +from educe.learning.svmlight_format import dump_svmlight_file +from educe.rst_dt.annotation import EDU as RstEDU # WIP load_edu_input_file # pylint: disable=invalid-name # a lot of the names here are chosen deliberately to @@ -170,8 +169,19 @@ def labels_comment(class_mapping): return comment -def _load_labels(f): - """Actually read the label set""" +def _load_labels_file(f): + """Actually read the label set from a mapping file. + + Parameters + ---------- + f : str + Mapping file, each line pairs an integer index with a label. + + Returns + ------- + labels : dict from str to int + Mapping from relation label to integer. + """ labels = dict() for line in f: i, lbl = line.strip().split() @@ -180,8 +190,54 @@ def _load_labels(f): return labels -def load_labels(f): - """Read label set into a dictionary mapping labels to indices""" +def _load_labels_header(f): + """Actually read the label set from the header of a features file. + + Previous versions of educe dumped the labels in the header of the + svmlight features file: The first line was commented and contained + the list of labels, mapped to indices from 1 to n. + + Parameters + ---------- + f : str + Features file, whose first line is a comment with the list of labels. + + Returns + ------- + labels : dict from str to int + Mapping from relation label to integer. + """ + line = f.readline() + seq = line[1:].split()[1:] + labels = {lbl: idx for idx, lbl in enumerate(seq, start=1)} + labels['__UNK__'] = 0 + return labels + + +def load_labels(f, stored_as='header'): + """Read label set into a dictionary mapping labels to indices. + + Parameters + ---------- + f : str + File containing the labels. + stored_as : str, one of {'header', 'file'} + Storage mode of the labelset, as the `header` (commented first + line) of an svmlight features file, or as an independent `file` + where each line pairs an integer index with a label. + + Returns + ------- + labels : dict from str to int + Mapping from relation label to integer. + """ + if stored_as == 'header': + _load_labels = _load_labels_as_header + elif stored_as == 'file': + _load_labels = _load_labels_as_file + else: + raise ValueError( + "load_labels: stored_as must be one of {'header', 'file'}") with codecs.open(f, 'r', 'utf-8') as f: return _load_labels(f) @@ -204,7 +260,7 @@ def dump_labels(labelset, f): _dump_labels(labelset, f) -def dump_all(X_gen, y_gen, f, docs, instance_generator): +def dump_all(X_gen, y_gen, f, docs, instance_generator, class_mapping=None): """Dump a whole dataset: features (in svmlight) and EDU pairs. 
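    # Rough calling sketch for the optional ``class_mapping`` added here
    # (the label file, output path and generators are placeholders):
    #
    #     labelset = load_labels('labels.txt', stored_as='file')
    #     dump_all(X_gen, y_gen, 'train.relations.sparse', docs,
    #              instance_generator, class_mapping=labelset)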
Parameters @@ -214,15 +270,17 @@ def dump_all(X_gen, y_gen, f, docs, instance_generator): y_gen : iterable of iterable of int Ground truth labels. - f : str Output features file path - docs : list of DocumentPlus Documents - instance_generator : function from doc to iterable of pairs TODO + class_mapping : dict(str, int), optional + Mapping from label to int. If None, it is ignored so you need + to check a proper call to dump_labels has been made elsewhere. + If not None, the list of labels ordered by index is written as + the header of the svmlight features file, as a comment line. """ # dump EDUs edu_input_file = f + '.edu_input' @@ -232,4 +290,10 @@ def dump_all(X_gen, y_gen, f, docs, instance_generator): dump_pairings_file((instance_generator(doc) for doc in docs), pairings_file) # dump vectorized pairings with label - dump_svmlight_file(X_gen, y_gen, f) + # the labelset will be written in a comment at the beginning of the + # svmlight file + if class_mapping is not None: + comment = labels_comment(class_mapping) + else: + comment = '' + dump_svmlight_file(X_gen, y_gen, f, comment=comment) From b55fa25bb7226dcaf68aaa41c5562d8bc64ee012 Mon Sep 17 00:00:00 2001 From: moreymat Date: Fri, 9 Jun 2017 17:15:47 +0200 Subject: [PATCH 40/44] MAINT minor changes in layout, docstring --- educe/rst_dt/learning/cmd/extract.py | 61 +++++++++------------------- 1 file changed, 19 insertions(+), 42 deletions(-) diff --git a/educe/rst_dt/learning/cmd/extract.py b/educe/rst_dt/learning/cmd/extract.py index b56df66..105ddcf 100644 --- a/educe/rst_dt/learning/cmd/extract.py +++ b/educe/rst_dt/learning/cmd/extract.py @@ -22,15 +22,16 @@ import educe.stac import educe.util -from educe.learning.edu_input_format import (dump_all, dump_labels, - load_labels) -from educe.learning.vocabulary_format import (dump_vocabulary, - load_vocabulary) -from ..args import add_usual_input_args -from ..doc_vectorizer import DocumentCountVectorizer, DocumentLabelExtractor +from educe.learning.edu_input_format import ( + dump_all, dump_labels, load_labels) +from educe.learning.vocabulary_format import ( + dump_vocabulary, load_vocabulary) +from educe.rst_dt.corenlp import CoreNlpParser from educe.rst_dt.corpus import RstDtParser +from educe.rst_dt.learning.args import add_usual_input_args +from educe.rst_dt.learning.doc_vectorizer import ( + DocumentCountVectorizer, DocumentLabelExtractor) from educe.rst_dt.ptb import PtbParser -from educe.rst_dt.corenlp import CoreNlpParser NAME = 'extract' @@ -131,11 +132,9 @@ def config_argparser(parser): # --------------------------------------------------------------------- # main # --------------------------------------------------------------------- - def extract_dump_instances(docs, instance_generator, feature_set, - lecsie_data_dir, vocabulary, - split_feat_space, labels, - live, ordered_pairs, output, corpus, + lecsie_data_dir, vocabulary, split_feat_space, + labels, live, ordered_pairs, output, corpus, file_split='corpus'): """Extract and dump instances. @@ -143,31 +142,23 @@ def extract_dump_instances(docs, instance_generator, feature_set, ---------- docs : list of DocumentPlus Documents - instance_generator : (string, function) Instance generator: the first element is a string descriptor of the instance generator, the second is the instance generator itself: a function from DocumentPlus to list of EDU pairs. - vocabulary : filepath Path to vocabulary - split_feat_space : string Splitter for feature space - labels : filepath? Path to labelset? 
- live : TODO TODO - ordered_pairs : boolean If True, DU pairs (instances) are ordered pairs, i.e. (src, tgt) <> (tgt, src). - output : string Path to the output directory, e.g. 'TMP/data'. - corpus : TODO TODO """ @@ -207,11 +198,9 @@ def extract_dump_instances(docs, instance_generator, feature_set, vocab = None min_df = 5 - vzer = DocumentCountVectorizer(instance_gen, - feature_set, + vzer = DocumentCountVectorizer(instance_gen, feature_set, lecsie_data_dir=lecsie_data_dir, - min_df=min_df, - vocabulary=vocab, + min_df=min_df, vocabulary=vocab, split_feat_space=split_feat_space) # pylint: disable=invalid-name # X, y follow the naming convention in sklearn @@ -229,9 +218,8 @@ def extract_dump_instances(docs, instance_generator, feature_set, labelset = load_labels(labels) else: labelset = None - labtor = DocumentLabelExtractor(instance_gen, - ordered_pairs=ordered_pairs, - labelset=labelset) + labtor = DocumentLabelExtractor( + instance_gen, ordered_pairs=ordered_pairs, labelset=labelset) if labels is not None: labtor.fit(docs) y_gen = labtor.transform(docs) @@ -316,16 +304,12 @@ def main(args): # retrieve parameters feature_set = args.feature_set live = args.parsing - - # NEW lecsie features - lecsie_data_dir = args.lecsie_data_dir - + lecsie_data_dir = args.lecsie_data_dir # NEW lecsie features # RST data # fileX docs are currently not supported by CoreNLP exclude_file_docs = args.corenlp_out_dir - rst_reader = RstDtParser(args.corpus, args, - coarse_rels=args.coarse, + rst_reader = RstDtParser(args.corpus, args, coarse_rels=args.coarse, fix_pseudo_rels=args.fix_pseudo_rels, nary_enc=args.nary_enc, exclude_file_docs=exclude_file_docs) @@ -423,15 +407,8 @@ def open_plus(doc): lambda doc: doc.all_edu_pairs( ordered=ordered_pairs)) split_feat_space = 'dir_sent' - # do the extraction extract_dump_instances(docs, instance_generator, feature_set, - lecsie_data_dir, - args.vocabulary, - split_feat_space, - args.labels, - live, - ordered_pairs, - args.output, - args.corpus, - file_split=args.file_split) + lecsie_data_dir, args.vocabulary, split_feat_space, + args.labels, live, ordered_pairs, args.output, + args.corpus, file_split=args.file_split) From 80caf8f7a2ced515a3a4f9e72842d5cc968726f7 Mon Sep 17 00:00:00 2001 From: moreymat Date: Fri, 9 Jun 2017 17:25:46 +0200 Subject: [PATCH 41/44] MAINT rm unnecessary imports --- educe/rst_dt/learning/cmd/extract.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/educe/rst_dt/learning/cmd/extract.py b/educe/rst_dt/learning/cmd/extract.py index 105ddcf..ce440fe 100644 --- a/educe/rst_dt/learning/cmd/extract.py +++ b/educe/rst_dt/learning/cmd/extract.py @@ -9,13 +9,8 @@ """ from __future__ import print_function -from collections import defaultdict -import csv import itertools -from glob import glob import os -import sys -import time import educe.corpus import educe.glozz From a855be05a01c22251084b74daa0bfd3fffada6f9 Mon Sep 17 00:00:00 2001 From: moreymat Date: Fri, 9 Jun 2017 17:34:01 +0200 Subject: [PATCH 42/44] FIX load_labels: set default to 'file' --- educe/learning/edu_input_format.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/educe/learning/edu_input_format.py b/educe/learning/edu_input_format.py index 5700e18..e063980 100644 --- a/educe/learning/edu_input_format.py +++ b/educe/learning/edu_input_format.py @@ -214,7 +214,7 @@ def _load_labels_header(f): return labels -def load_labels(f, stored_as='header'): +def load_labels(f, stored_as='file'): """Read label set into a dictionary mapping labels to indices. 
Parameters From fffbbff89cc7fb0fc462a651f3ed62f4f2c56e1a Mon Sep 17 00:00:00 2001 From: moreymat Date: Tue, 13 Jun 2017 15:18:11 +0200 Subject: [PATCH 43/44] MAINT minor refactoring, cleanups --- educe/rst_dt/learning/cmd/extract.py | 5 +---- educe/stac/learning/cmd/extract.py | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/educe/rst_dt/learning/cmd/extract.py b/educe/rst_dt/learning/cmd/extract.py index ce440fe..8ad3cda 100644 --- a/educe/rst_dt/learning/cmd/extract.py +++ b/educe/rst_dt/learning/cmd/extract.py @@ -159,11 +159,9 @@ def extract_dump_instances(docs, instance_generator, feature_set, """ # get instance generator and its descriptor instance_descr, instance_gen = instance_generator - # setup persistency if not os.path.exists(output): os.makedirs(output) - corpus_name = os.path.basename(corpus) if live: @@ -192,7 +190,6 @@ def extract_dump_instances(docs, instance_generator, feature_set, else: vocab = None min_df = 5 - vzer = DocumentCountVectorizer(instance_gen, feature_set, lecsie_data_dir=lecsie_data_dir, min_df=min_df, vocabulary=vocab, @@ -210,7 +207,7 @@ def extract_dump_instances(docs, instance_generator, feature_set, y_gen = itertools.repeat(0) else: if labels is not None: - labelset = load_labels(labels) + labelset = load_labels(labels, stored_as='file') else: labelset = None labtor = DocumentLabelExtractor( diff --git a/educe/stac/learning/cmd/extract.py b/educe/stac/learning/cmd/extract.py index e901070..ce65735 100755 --- a/educe/stac/learning/cmd/extract.py +++ b/educe/stac/learning/cmd/extract.py @@ -190,7 +190,7 @@ def main_pairs(args): instance_generator = lambda x: x.edu_pairs() if args.labels is not None: - labelset = load_labels(args.labels) + labelset = load_labels(args.labels, stored_as='file') labels = [lbl for lbl, idx in sorted(labelset.items(), key=lambda k, v: v)] # LabelVectorizer.__init__ automatically reserves the first three From 8c2bd10995aae2500bec2d76c02f5f7a55337984 Mon Sep 17 00:00:00 2001 From: moreymat Date: Fri, 30 Jun 2017 11:48:16 +0200 Subject: [PATCH 44/44] MAINT correct minor divergences to get closer to enh-dump-formats --- educe/learning/edu_input_format.py | 1 - educe/learning/keygroup_vectorizer.py | 3 --- educe/rst_dt/learning/cmd/extract.py | 1 - educe/stac/learning/cmd/extract.py | 1 - 4 files changed, 6 deletions(-) diff --git a/educe/learning/edu_input_format.py b/educe/learning/edu_input_format.py index e063980..39b4394 100644 --- a/educe/learning/edu_input_format.py +++ b/educe/learning/edu_input_format.py @@ -267,7 +267,6 @@ def dump_all(X_gen, y_gen, f, docs, instance_generator, class_mapping=None): ---------- X_gen : iterable of iterable of int arrays Feature vectors. - y_gen : iterable of iterable of int Ground truth labels. f : str diff --git a/educe/learning/keygroup_vectorizer.py b/educe/learning/keygroup_vectorizer.py index 9103152..c468e89 100644 --- a/educe/learning/keygroup_vectorizer.py +++ b/educe/learning/keygroup_vectorizer.py @@ -26,7 +26,6 @@ def _count_vocab(self, vectors, fixed_vocab=False): vectors : list of list of KeyGroup List of feature matrices, one list per doc, one line per sample. - fixed_vocab : boolean, defaults to False If True, use the vocabulary that hopefully has already been set during `fit()`. @@ -35,7 +34,6 @@ def _count_vocab(self, vectors, fixed_vocab=False): ------- vocabulary : dict(str, int) Mapping from features to integers. - X : list of list of list of tuple(int, float) List of feature matrices. 
""" @@ -84,7 +82,6 @@ def _count_vocab(self, vectors, fixed_vocab=False): # print('doc ', str(i)) # DEBUG x = feature_acc[current_row:next_row] X[-1].append(x) - return vocabulary, X def fit_transform(self, vectors): diff --git a/educe/rst_dt/learning/cmd/extract.py b/educe/rst_dt/learning/cmd/extract.py index 8ad3cda..e9ec9f8 100644 --- a/educe/rst_dt/learning/cmd/extract.py +++ b/educe/rst_dt/learning/cmd/extract.py @@ -253,7 +253,6 @@ def extract_dump_instances(docs, instance_generator, feature_set, else: raise ValueError('Unknown value for args.file_split : {}'.format( args.file_split)) - # dump labelset if labels is not None: # relative path to get a correct symlink diff --git a/educe/stac/learning/cmd/extract.py b/educe/stac/learning/cmd/extract.py index ce65735..c282662 100755 --- a/educe/stac/learning/cmd/extract.py +++ b/educe/stac/learning/cmd/extract.py @@ -248,7 +248,6 @@ def main_pairs(args): dump_all(X_gen, y_gen, out_file, dialogues, instance_generator) # end WIP - # dump vocabulary vocab_file = fp.join(outdir, '{corpus_name}.relations.sparse.vocab'.format(