- rel_scores, rels = rel_scores[arcs], rels[arcs]
- arc_loss = arc_loss(arcs, arc_scores)
- rel_loss = rel_loss(rels, rel_scores)
- loss = arc_loss + rel_loss
-
- return loss
-
def decode(self, arc_scores, rel_scores, mask):
    """Turn arc and relation scores into discrete predictions.

    Args:
        arc_scores: Arc scores. An arc is predicted wherever its score is
            strictly greater than 0 (presumably unnormalized logits --
            TODO confirm against the scorer).
        rel_scores: Relation scores. The predicted relation is the argmax
            over the last axis.
        mask: Not used by the threshold decoder below. It would be needed
            by tree-constrained decoding (``self.config.tree``), which is
            not implemented.

    Returns:
        A tuple ``(arc_preds, rel_preds)``. ``arc_preds`` is the boolean
        result of ``arc_scores > 0``. ``rel_preds`` holds the index of the
        highest-scoring relation along the last axis of ``rel_scores``.

    Raises:
        NotImplementedError: if ``self.config.tree`` is truthy.
    """
    if self.config.tree:
        # TODO: tree-constrained decoding (e.g. Eisner's algorithm over
        # arc_scores and mask) is not implemented yet.
        # Must be NotImplementedError: NotImplemented is a plain constant,
        # and calling it raises an unrelated TypeError.
        raise NotImplementedError('Give me some time...')

    # NOTE(review): mask is not applied here, so padded positions can also
    # come out as predicted arcs -- confirm callers mask the result.
    arc_preds = arc_scores > 0
    rel_preds = tf.argmax(rel_scores, -1)
    return arc_preds, rel_preds