This repository has been archived by the owner on Oct 13, 2022. It is now read-only.

Torchscriptable Conformer + high-level "simple" object for decoding, alignments, posteriors, plotting them, etc. #206

Merged
merged 10 commits into from
Jun 8, 2021

Conversation

pzelasko
Copy link
Collaborator

@pzelasko pzelasko commented Jun 2, 2021

Note: this is unfinished; I put together multiple code snippets but didn't polish them in any way. It could become a part of Icefall with a bit of work and documentation.

You can use it, e.g., in a Jupyter notebook to plot posteriors and alignments; note the high-level methods (which all accept cuts as inputs):

  • compute_features
  • compute_posteriors
  • decode
  • align
  • align_ctm
  • plot_alignments
  • plot_posteriors

You can get the plots from #203 with this.
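A minimal runnable sketch of what such a high-level wrapper could look like. The method names match the list above; everything else (the toy linear model standing in for the Conformer, plain feature tensors standing in for Lhotse cuts, greedy argmax standing in for FSA-based decoding) is an assumption made only so the sketch runs standalone:

```python
import torch
import torch.nn as nn


class ASR:
    """Sketch of the high-level wrapper; NOT the actual PR API."""

    def __init__(self, model: nn.Module):
        # In the real object this would be the (scripted) Conformer.
        self.model = model.eval()

    def compute_posteriors(self, feats: torch.Tensor) -> torch.Tensor:
        # Frame-level log-posteriors over output units.
        with torch.no_grad():
            return self.model(feats).log_softmax(dim=-1)

    def decode(self, feats: torch.Tensor) -> torch.Tensor:
        # Greedy argmax stand-in for real HLG decoding.
        return self.compute_posteriors(feats).argmax(dim=-1)


asr = ASR(nn.Linear(80, 10))
posts = asr.compute_posteriors(torch.randn(4, 80))  # shape (4, 10)
hyps = asr.decode(torch.randn(4, 80))               # shape (4,)
```

The real class additionally wraps feature extraction, alignment, and plotting, all keyed on cuts rather than raw tensors.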

@danpovey
Copy link
Contributor

danpovey commented Jun 3, 2021

I don't think I want to go in this direction of having things all bound together in classes -- at least for the time being.
I want to go with simple utility functions with interfaces as small as possible.
However, I am open to merging this simply as a useful reference for things we might need to do, like plotting.

@pzelasko
Copy link
Collaborator Author

pzelasko commented Jun 3, 2021

That makes sense to me, TBH, because I'm not even sure how to make a class like this generic enough to handle the different types of models, topologies, techniques, etc. that we're going to introduce. Anyway, I find it helpful to have the "current best" model to work with in other projects. I'll test it a bit more to make sure everything works as intended, and then I'll merge.

@danpovey
Copy link
Contributor

danpovey commented Jun 4, 2021 via email

@csukuangfj
Copy link
Collaborator

csukuangfj commented Jun 7, 2021

I find it helpful to work with the "current best" model in other projects.

I would propose letting the code support any PyTorch model, at least those supported by TorchScript.

The user only needs to provide a .pt file, which contains everything needed to run the model.

I just wrote a small demo (see below) to show that the idea is feasible. You can see that with demo.pt at hand, we don't need the definition of the model to run its forward function.

I believe the following small utilities would be helpful:

(1) compute-post

  • Input: Features
  • Output: The output of the model

(2) decode

  • Input 1: The output of compute-post
  • Input 2: HLG
  • Output: The decoding result, e.g., word sequences, phone sequences, etc. We can represent it using:

    class Alignment:
        # The key of the dict indicates the type of the alignment,
        # e.g., ilabel, phone_label, etc.
        #
        # The value of the dict is the actual alignments.
        # If the alignments are frame-wise and the sampling rate
        # is available, they can be converted to CTM format, like the
        # one used in Lhotse.
        value: Dict[str, Union[List[int], List[str]]]

(3) show-ali

  • Input: The output of decode
  • Output: some type of visualization. E.g, we can use praat, which supports TextGrid format, which can be converted from CTM format. An example visualization is shown below. There is a Python module https://parselmouth.readthedocs.io/en/stable/index.html for praat, so I believe it is possible to obtain the following figure from the output of decode by some Python code, without opening a GUI window.

[Screenshot: example Praat TextGrid alignment visualization, 2021-06-07]
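The CTM-to-TextGrid conversion mentioned above can be sketched in plain Python. This is only a sketch: real CTM entries also carry channel/confidence fields, Praat expects interval tiers to cover the whole time span without gaps, and the function name is hypothetical:

```python
from typing import List, Tuple


def ctm_to_textgrid(ctm: List[Tuple[float, float, str]], xmax: float) -> str:
    """Render (start, duration, word) CTM entries as a minimal Praat
    TextGrid (long text format) with a single word tier."""
    lines = [
        'File type = "ooTextFile"',
        'Object class = "TextGrid"',
        '',
        'xmin = 0',
        f'xmax = {xmax}',
        'tiers? <exists>',
        'size = 1',
        'item []:',
        '    item [1]:',
        '        class = "IntervalTier"',
        '        name = "words"',
        '        xmin = 0',
        f'        xmax = {xmax}',
        f'        intervals: size = {len(ctm)}',
    ]
    for i, (start, dur, word) in enumerate(ctm, 1):
        lines += [
            f'        intervals [{i}]:',
            f'            xmin = {start}',
            f'            xmax = {start + dur}',
            f'            text = "{word}"',
        ]
    return '\n'.join(lines)


tg = ctm_to_textgrid([(0.0, 0.5, 'hello'), (0.5, 0.4, 'world')], xmax=1.0)
```

The resulting string can be written to a .TextGrid file and opened with Praat or parselmouth.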


#!/usr/bin/env python3

import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        return self.linear(x)


@torch.no_grad()
def main():
    m = Model(in_dim=2, out_dim=3)

    x = torch.tensor([1.0, 2.0])
    y = m(x)

    # Compile the model to TorchScript and serialize it; demo.pt
    # carries both the architecture and the parameters.
    script_module = torch.jit.script(m)
    script_module.save('demo.pt')

    # Load it back without access to the Model class definition.
    new_m = torch.jit.load('demo.pt')
    new_y = new_m(x)
    print(y)
    print(new_y)


if __name__ == '__main__':
    main()

@danpovey
Copy link
Contributor

danpovey commented Jun 7, 2021

Perhaps someone can test whether our current models are supported by TorchScript, or at least whether it would be possible to make them supported?

@pzelasko
Copy link
Collaborator Author

pzelasko commented Jun 7, 2021

I was just able to convert the Conformer to TorchScript with some changes; I'll make it a part of this PR.

@danpovey
Copy link
Contributor

danpovey commented Jun 7, 2021 via email

@pzelasko
Copy link
Collaborator Author

pzelasko commented Jun 7, 2021

OK, a summary of this PR:

  • all the methods in the ASR class work correctly
  • Conformer is TorchScriptable
  • if I manually create a TorchScript checkpoint, everything seems to work OK

Things that still don't work:

  • I don't know if TorchScript is intended to be used during training (I think so?), but it slows the training down by a lot. This is probably some sort of incorrect usage on my part.
  • I'm not sure if the TorchScript checkpoint saving works correctly; the training run I'm using to test it is still going.

I'm not sure if I have the capacity to work on this further for now -- in any case, Conformer is now scriptable which should open the way for others.

@pzelasko pzelasko changed the title WIP: high-level "simple" object for decoding, alignments, posteriors, plotting them, etc. Torchscriptable Conformer + high-level "simple" object for decoding, alignments, posteriors, plotting them, etc. Jun 7, 2021
@pzelasko
Copy link
Collaborator Author

pzelasko commented Jun 7, 2021

Saving models to TorchScript works during training with the --torchscript-epoch <start-saving-epoch> flag, with no speed issues, because the conversion to TorchScript happens just before storing. The issues appear only when we train using a ScriptModule (--torchscript true flag).
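The save-time conversion described above amounts to something like the following sketch (the toy model and file path are placeholders; only the pattern of scripting right before saving is the point):

```python
import os
import tempfile

import torch
import torch.nn as nn

# Toy stand-in for the Conformer.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# ... the eager-mode training loop runs here at full speed ...

# At checkpoint time, convert to TorchScript and save:
path = os.path.join(tempfile.mkdtemp(), 'epoch-10.pt')
torch.jit.script(model).save(path)

# Downstream code can load it without the model's Python definition:
restored = torch.jit.load(path)
out = restored(torch.randn(1, 4))
```

Because scripting happens only once per saved epoch, the training loop itself never runs a ScriptModule.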

It is OK to merge from my side -- please review and merge if it's adequate.

@pzelasko
Copy link
Collaborator Author

pzelasko commented Jun 7, 2021

One last remark which I forgot to mention -- I did only very naive benchmarking by running in Jupyter the following snippet:

%%timeit
with torch.no_grad():
    model(features, supervisions)

The improvement from the normal to the TorchScripted model was small -- 140 ms vs. 130 ms, on a V100 GPU with ~30 cuts in the batch. So it was only during training that I noticed the slowdown.
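Outside Jupyter, the same kind of naive comparison can be done with a plain timing loop. This is a toy sketch (a small linear layer stands in for the Conformer, shapes are arbitrary), and note that a few warm-up calls matter because the TorchScript profiling executor optimizes after the first runs:

```python
import time

import torch
import torch.nn as nn

model = nn.Linear(80, 500).eval()   # toy stand-in for the real model
scripted = torch.jit.script(model)
x = torch.randn(32, 80)             # roughly "a batch of cuts"


def bench(m: nn.Module, iters: int = 50) -> float:
    """Average per-call latency in seconds (very naive)."""
    with torch.no_grad():
        for _ in range(3):          # warm-up; lets the JIT optimize
            m(x)
        start = time.perf_counter()
        for _ in range(iters):
            m(x)
    return (time.perf_counter() - start) / iters


eager_t = bench(model)
scripted_t = bench(scripted)
```

On GPU, a proper benchmark would also need torch.cuda.synchronize() around the timed region.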

@mthrok
Copy link

mthrok commented Jun 7, 2021

One last remark which I forgot to mention -- I did only very naive benchmarking by running in Jupyter the following snippet:

%%timeit
with torch.no_grad():
    model(features, supervisions)

The improvement from the normal to the TorchScripted model was small -- 140 ms vs. 130 ms, on a V100 GPU with ~30 cuts in the batch. So it was only during training that I noticed the slowdown.

FYI: though TorchScript was originally advertised for performance, it now mainly solves the problem of deployment: the resulting object is deployable to C++/iOS/Android. The performance-improvement efforts were moved to AI compilers, so in general we can't expect a speedup -- consider yourself lucky if you get any.

This might not be relevant to your application, but when scripting a model, it is possible to perform irreversible operations. For example, recently I added a TorchScript-able wav2vec2 to torchaudio. For the sake of supporting quantization, I added a hook for scripting that removes the weight-normalization forward hook. My rationale was that the model was mainly intended for inference, so removing a hook is fine. However, if a model is scripted during training, then the wav2vec2 model from torchaudio becomes incompatible with snowfall in a somewhat unexpected way.

Since a TorchScript object file contains only the architecture and parameters, it feels to me that creating a tool that makes a scripted model from a training checkpoint file is simpler.

However, I do not know the design principles of snowfall or the context of this work; if scripting during training is what's desired, I think it's okay.
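A checkpoint-conversion tool of the kind described above might look like the sketch below. A toy linear layer stands in for the real architecture; the catch, as noted later in this thread, is that the tool must be able to rebuild the model, which is exactly the information a plain state_dict checkpoint does not carry:

```python
import os
import tempfile

import torch
import torch.nn as nn


def build_model() -> nn.Module:
    # In a real tool this would construct the Conformer with the
    # correct hyperparameters -- information the checkpoint lacks.
    return nn.Linear(4, 2)


def convert_checkpoint_to_torchscript(ckpt_path: str, out_path: str) -> None:
    model = build_model()
    model.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
    model.eval()
    torch.jit.script(model).save(out_path)


tmp = tempfile.mkdtemp()
ckpt = os.path.join(tmp, 'checkpoint.pt')
scripted_path = os.path.join(tmp, 'scripted.pt')
torch.save(build_model().state_dict(), ckpt)       # fake training checkpoint
convert_checkpoint_to_torchscript(ckpt, scripted_path)
restored = torch.jit.load(scripted_path)           # no build_model needed here
```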

@mthrok
Copy link

mthrok commented Jun 7, 2021

If you are looking for a way to speed up training, quantization-aware training is one approach. I heard there was a case where it improved both training time AND accuracy at the same time.

@pzelasko
Copy link
Collaborator Author

pzelasko commented Jun 7, 2021

Thanks @mthrok, that makes a lot of sense. I wondered whether people use TorchScript to speed up training, but now it's clear that's not the case. BTW, could you elaborate on the AI compiler? Is it Glow, or something else?

The --torchscript-epoch option is basically what a checkpoint conversion tool would have done -- except that to write such a tool, we would need to provide it with all the info such as architecture, hparams, etc. So I guess the hope is that by storing a torchscripted model in the training script, downstream applications won't need to know all the hparams required to instantiate the model. But maybe @danpovey and @csukuangfj will have a different view. In any case, we should also support weight averaging before storing the torchscripted model, as it consistently improves the results.
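The weight averaging mentioned above is, in its simplest form, a uniform average of parameters across several checkpoints. A minimal sketch (the function name is hypothetical; it assumes all state dicts share keys and shapes, and a real implementation would treat integer buffers such as BatchNorm counters specially):

```python
from typing import Dict, List

import torch


def average_checkpoints(
    state_dicts: List[Dict[str, torch.Tensor]]
) -> Dict[str, torch.Tensor]:
    """Uniformly average parameters across checkpoint state dicts."""
    avg = {}
    for key in state_dicts[0]:
        avg[key] = torch.stack(
            [sd[key].float() for sd in state_dicts]
        ).mean(dim=0)
    return avg


a = {'w': torch.tensor([0.0, 2.0])}
b = {'w': torch.tensor([2.0, 4.0])}
avg = average_checkpoints([a, b])   # {'w': tensor([1., 3.])}
```

The averaged state dict would be loaded into the model before it is scripted and saved.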

@mthrok
Copy link

mthrok commented Jun 7, 2021

Thanks @mthrok, that makes a lot of sense. I wondered whether people use TorchScript to speed up training, but now it's clear that's not the case. BTW, could you elaborate on the AI compiler? Is it Glow, or something else?

Yeah, I think one of them (but I do not know much, so please take it with a grain of salt 😥). There is also torch.fx.

@pzelasko
Copy link
Collaborator Author

pzelasko commented Jun 8, 2021

There's also an ONNX exporter that converts TorchScript modules: https://pytorch.org/docs/stable/onnx.html

Maybe it's worth looking into.

@danpovey
Copy link
Contributor

danpovey commented Jun 8, 2021

Thanks!
I'll merge so we don't get too out of date.

@danpovey danpovey merged commit 2d38a89 into k2-fsa:master Jun 8, 2021