theseus

A library for differentiable nonlinear optimization

Theseus is an efficient application-agnostic library for building custom nonlinear optimization layers in PyTorch to support constructing various problems in robotics and vision as end-to-end differentiable architectures.

Differentiable nonlinear optimization provides a general scheme to encode inductive priors, as the objective function can be parameterized partly by neural models and partly by expert domain-specific differentiable models. Differentiating through the optimizer retains the ability to compute gradients end-to-end, which allows neural models to train on the final task loss while also taking advantage of priors captured by the optimizer.
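
To make "differentiating through the optimizer" concrete, here is a minimal sketch in plain Python (no Theseus or PyTorch) on a toy least-squares problem with residuals $r_i = y_i - v e^{x_i}$. The inner problem is solved for $v$ with Gauss-Newton, and the sensitivity of the solution to each $x_i$ is obtained from the implicit function theorem applied to the stationarity condition. The function names `gauss_newton_v` and `implicit_grad` are hypothetical and chosen for illustration only; this is not a description of how Theseus is implemented internally.

```python
import math


def gauss_newton_v(xs, ys, v0=1.0, iters=5):
    """Minimize sum_i (y_i - v * exp(x_i))**2 over the scalar v."""
    v = v0
    for _ in range(iters):
        # Residuals r_i = y_i - v * exp(x_i); Jacobian dr_i/dv = -exp(x_i).
        r = [y - v * math.exp(x) for x, y in zip(xs, ys)]
        J = [-math.exp(x) for x in xs]
        JtJ = sum(j * j for j in J)
        Jtr = sum(j * ri for j, ri in zip(J, r))
        v -= Jtr / JtJ  # Gauss-Newton update: v <- v - (J^T J)^{-1} J^T r
    return v


def implicit_grad(xs, ys, v_star):
    """Return dv*/dx_i for every i, using the implicit function theorem.

    At the optimum, g(v, x) = sum_i exp(x_i) * (y_i - v * exp(x_i)) = 0, so
    dv*/dx_i = -(dg/dx_i) / (dg/dv).
    """
    dg_dv = -sum(math.exp(2 * x) for x in xs)
    return [
        -(math.exp(x) * y - 2 * v_star * math.exp(2 * x)) / dg_dv
        for x, y in zip(xs, ys)
    ]


xs = [0.0, 0.5, 1.0]
ys = [2.0 * math.exp(x) for x in xs]  # synthetic data generated with v = 2
v_star = gauss_newton_v(xs, ys)
grads = implicit_grad(xs, ys, v_star)
print("v* =", v_star)
print("dv*/dx =", grads)
```

An outer loss defined on `v_star` can be backpropagated to `xs` by the chain rule through `grads`. This is the role the optimization layer plays inside a larger network.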

See the list of papers published using Theseus for examples across various application domains.


Current Features

Application agnostic interface

Our implementation provides an easy-to-use interface to build custom optimization layers and plug them into any neural architecture. The following differentiable features are currently available:

Efficiency based design

We support several features that improve computation times and memory consumption:

Getting Started

Prerequisites

  • We strongly recommend you install Theseus in a venv or conda environment with Python 3.8-3.10.
  • Theseus requires a torch installation. To install it for your particular CPU/CUDA configuration, follow the instructions on the PyTorch website.
  • For GPU support, Theseus requires nvcc to compile custom CUDA operations. Run nvcc --version and make sure the reported version matches the CUDA version used to compile PyTorch. If it does not, install the matching version and ensure its location is on your system’s $PATH variable.
  • Theseus also requires suitesparse, which you can install via:
    • sudo apt-get install libsuitesparse-dev (Ubuntu).
    • conda install -c conda-forge suitesparse (Mac).
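
The nvcc check in the list above can also be scripted. The sketch below is one possible way to do it, not an official Theseus tool: it parses the release number from the output of nvcc --version and prints it next to `torch.version.cuda`, the CUDA version PyTorch reports it was built with. The helper names (`parse_nvcc_release`, `nvcc_release`, `torch_cuda_release`) are hypothetical.

```python
import re
import shutil
import subprocess


def parse_nvcc_release(version_text):
    """Extract the 'major.minor' release from `nvcc --version` output."""
    match = re.search(r"release (\d+\.\d+)", version_text)
    return match.group(1) if match else None


def nvcc_release():
    """Return the release reported by the nvcc on $PATH, or None if nvcc is missing."""
    if shutil.which("nvcc") is None:
        return None
    result = subprocess.run(
        ["nvcc", "--version"], capture_output=True, text=True, check=False
    )
    return parse_nvcc_release(result.stdout)


def torch_cuda_release():
    """Return the CUDA version PyTorch was built with, or None if unavailable."""
    try:
        import torch
    except ImportError:
        return None
    return torch.version.cuda  # None for CPU-only builds


if __name__ == "__main__":
    print("nvcc release:      ", nvcc_release())
    print("torch CUDA release:", torch_cuda_release())
```

If the two printed values differ, or nvcc is reported as missing, revisit the nvcc installation and your $PATH before installing Theseus with GPU support.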

Installing

  • pypi

    pip install theseus-ai
    

    We currently provide wheels with our CUDA extensions compiled using CUDA 11.6 and Python 3.10.
    For other CUDA versions, consider installing from source or using our
    build script.

    Note that the PyPI installation doesn’t include our experimental Theseus Labs.
    For this, please install from source.

  • From source

    The simplest way to install Theseus from source is by running the following (see further below to also include BaSpaCho)

    git clone https://github.com/facebookresearch/theseus.git && cd theseus
    pip install -e .
    

    If you are interested in contributing to Theseus, install the dev extras and pre-commit hooks instead:

    pip install -e ".[dev]"
    pre-commit install
    

    and follow the more detailed instructions in CONTRIBUTING.

  • Installing BaSpaCho extensions from source

    By default, installing from source doesn’t include our BaSpaCho sparse solver extension. For this, follow these steps:

    1. Compile BaSpaCho from source following instructions here. We recommend using flags -DBLA_STATIC=ON -DBUILD_SHARED_LIBS=OFF.

    2. Run

      git clone https://github.com/facebookresearch/theseus.git && cd theseus
      BASPACHO_ROOT_DIR=<path/to/root/baspacho/dir> pip install -e .
      

      where the BaSpaCho root dir must have the binaries in the subdirectory build.

Running unit tests (requires dev installation)

python -m pytest tests

By default, unit tests include tests for our CUDA extensions. You can add the option -m "not cudaext"
to skip them when installing without CUDA support. Additionally, the tests for sparse solver BaSpaCho are automatically
skipped when its extlib is not compiled.

Examples

Simple example. This example fits the curve $y = v e^x$ to a dataset of $N$ observations $(x, y) \sim D$. The problem is modeled as an Objective with a single CostFunction that computes the residual $y - v e^x$. The Objective and the GaussNewton optimizer are encapsulated into a TheseusLayer. Using Adam and an MSE loss, $x$ is learned by differentiating through the TheseusLayer.

import torch
import theseus as th

# read_data() is a placeholder for your own data loader
x_true, y_true, v_true = read_data() # shapes (1, N), (1, N), (1, 1)
x = th.Variable(torch.randn_like(x_true), name="x")
y = th.Variable(y_true, name="y")
v = th.Vector(1, name="v") # a manifold subclass of Variable for optim_vars

def error_fn(optim_vars, aux_vars): # returns y - v * exp(x)
    x, y = aux_vars
    return y.tensor - optim_vars[0].tensor * torch.exp(x.tensor)

objective = th.Objective()
cost_function = th.AutoDiffCostFunction(
    [v], error_fn, y_true.shape[1], aux_vars=[x, y],
    cost_weight=th.ScaleCostWeight(1.0))
objective.add(cost_function)
layer = th.TheseusLayer(th.GaussNewton(objective, max_iterations=10))

phi = torch.nn.Parameter(x_true + 0.1 * torch.ones_like(x_true))
outer_optimizer = torch.optim.Adam([phi], lr=0.001)
for epoch in range(10):
    outer_optimizer.zero_grad() # clear gradients left over from the previous epoch
    solution, info = layer.forward(
        input_tensors={"x": phi.clone(), "v": torch.ones(1, 1)},
        optimizer_kwargs={"backward_mode": "implicit"})
    outer_loss = torch.nn.functional.mse_loss(solution["v"], v_true)
    outer_loss.backward()
    outer_optimizer.step()
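
The example calls read_data() without defining it. To run it end to end, one option is a synthetic stand-in such as the sketch below, which samples $x$, fixes a ground-truth $v$, and generates noisy observations $y = v e^x$ with the shapes the example expects. This function is hypothetical: its name matches the placeholder in the example, but its parameters and defaults (`num_points`, `v`, `noise_std`, `seed`) are assumptions made for illustration.

```python
import torch


def read_data(num_points=100, v=0.5, noise_std=0.01, seed=0):
    """Hypothetical stand-in for the example's read_data().

    Returns x_true, y_true, v_true with shapes (1, N), (1, N), (1, 1).
    """
    gen = torch.Generator().manual_seed(seed)
    x_true = torch.rand(1, num_points, generator=gen)  # inputs in [0, 1)
    v_true = torch.full((1, 1), v)  # ground-truth scale
    noise = noise_std * torch.randn(1, num_points, generator=gen)
    y_true = v_true * torch.exp(x_true) + noise  # y = v * exp(x) + noise
    return x_true, y_true, v_true
```

Define it before the example code so that the call to read_data() resolves.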

See tutorials, and robotics and vision examples to learn about the API and usage.

Citing Theseus

If you use Theseus in your work, please cite the paper with the BibTeX below.

@article{pineda2022theseus,
  title   = {{Theseus: A Library for Differentiable Nonlinear Optimization}},
  author  = {Luis Pineda and Taosha Fan and Maurizio Monge and Shobha Venkataraman and Paloma Sodhi and Ricky TQ Chen and Joseph Ortiz and Daniel DeTone and Austin Wang and Stuart Anderson and Jing Dong and Brandon Amos and Mustafa Mukadam},
  journal = {Advances in Neural Information Processing Systems},
  year    = {2022}
}

License

Theseus is MIT licensed. See the LICENSE for details.
