Learning Sparse Neural Networks through L0 Regularization

Model compression prunes parameters out of a large network, but how do we know which parameters are useless? This paper proposes a practical method that pushes the model toward using fewer parameters, yielding a sparse model. The conceptually attractive starting point is L_0 regularization, where the L_0 norm is defined as the number of non-zero parameters.

\def\!#1{\boldsymbol{#1}} \def\*#1{\mathbf{#1}} \DeclareMathOperator*{\argmin}{arg\,min} \DeclareMathOperator*{\argmax}{arg\,max} \begin{align} \mathcal{R}(\!\theta) = \frac{1}{N}\bigg(\sum_{i=1}^{N}\mathcal{L}\big(h&(\*x_i;\!\theta), \*y_i\big)\bigg) + \lambda \|\!\theta\|_0,\qquad \|\!\theta\|_0 = \sum_{j=1}^{|\theta|}\mathbb{I}[\theta_j \neq 0], \label{eq:erm_l0}\\ & \!\theta^* = \argmin_{\!\theta}\{\mathcal{R}(\!\theta)\}\nonumber, \end{align}
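As a quick sanity check, here is a minimal NumPy sketch of the penalized empirical risk in Eq. 1. The per-example losses and parameter vector are made-up toy values, and l0_penalized_risk is a hypothetical helper, not code from the paper.

```python
import numpy as np

def l0_penalized_risk(per_example_losses, theta, lam):
    """Empirical risk plus lambda times the number of non-zero parameters (Eq. 1)."""
    l0_norm = np.count_nonzero(theta)              # ||theta||_0 = sum_j I[theta_j != 0]
    return np.mean(per_example_losses) + lam * l0_norm

# Toy example: 5 per-example losses and a parameter vector with two exact zeros.
losses = np.array([0.7, 0.4, 0.9, 0.3, 0.5])
theta = np.array([0.0, 1.2, -0.3, 0.0, 2.1])
print(l0_penalized_risk(losses, theta, lam=0.01))  # 0.56 + 0.01 * 3 = 0.59
```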

Unfortunately, the L_0 norm is not differentiable, so we want to relax it. We can put a binary gate on top of each parameter:

\begin{align} \theta_j = \tilde{\theta}_j z_j, \qquad z_j \in \{0, 1\}, \qquad \tilde{\theta}_j \neq 0, \qquad \|\!\theta\|_0 = \sum_{j=1}^{|\theta|}z_j, \end{align}

where each gate z_j is governed by a Bernoulli distribution q(z_j|\pi_j) = \text{Bern}(\pi_j). Eq. 1 can then be reformulated as:

\DeclareMathOperator{\E}{\mathbb{E}} \begin{align} \mathcal{R}(\tilde{\!\theta}, \!\pi) & = \E_{q(\*z|\!\pi)}\bigg[\frac{1}{N}\bigg(\sum_{i=1}^{N}\mathcal{L}\big(h(\*x_i; \tilde{\!\theta}\odot\*z), \*y_i\big)\bigg)\bigg] + \lambda \sum_{j=1}^{|\theta|}\pi_j,\label{eq:erm_bern}\\ & \tilde{\!\theta}^*, \!\pi^* = \argmin_{\tilde{\!\theta}, \!\pi}\{\mathcal{R}(\tilde{\!\theta}, \!\pi)\}\nonumber, \end{align}
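To make the gating concrete, here is a tiny NumPy illustration of the Bernoulli-gated formulation (toy values of my own, not the paper's code): the expected L_0 penalty is simply the sum of the \pi_j, while a sampled \*z is a hard 0/1 mask.

```python
import numpy as np

rng = np.random.default_rng(0)
pi = np.array([0.9, 0.1, 0.5])            # Bernoulli gate probabilities
theta_tilde = np.array([1.0, -2.0, 0.3])  # free (non-zero) parameters

expected_l0 = pi.sum()                    # E_q[||theta||_0] = sum_j pi_j
z = rng.binomial(1, pi)                   # one draw of the discrete gates
theta = theta_tilde * z                   # effective parameters theta_tilde ⊙ z
# The penalty term is differentiable in pi, but z is a hard 0/1 draw, so
# gradients of the data loss cannot flow back into pi through the sample.
print(expected_l0, z, theta)
```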

Although the L_0 term is now differentiable (its expectation is simply \sum_j \pi_j), the first term becomes problematic: \*z is discrete, so gradients cannot flow through the Bernoulli samples back to \!\pi.

Since the Bernoulli gates do not work directly, the authors propose an alternative way to smooth the loss. Let \*s be a continuous random variable with a distribution q(\*s) parameterized by \!\phi, and let the hard-sigmoid rectification of \*s control the gate:

\begin{align} \*s &\sim q(\*s | \!\phi)\\ \*z &= \min(\*1, \max(\*0, \*s)). \end{align}

Because q(\*s) is continuous, we can use its cumulative distribution function Q to compute the probability of a gate being non-zero:

\begin{align} q(\*z \neq 0| \!\phi) = 1 - Q(\*s \leq 0|\!\phi), \end{align}
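For intuition, here is a small sketch of the rectified gate and the CDF-based penalty. Note that the Gaussian q(s|\phi) below is only an illustrative stand-in of my own choosing; the paper's actual choice, the hard concrete distribution, is introduced further down.

```python
import numpy as np
from scipy.stats import norm

def hard_sigmoid(s):
    """g(s) = min(1, max(0, s)): rectify a continuous sample into a gate in [0, 1]."""
    return np.minimum(1.0, np.maximum(0.0, s))

# Illustrative choice only: take q(s | phi) to be Gaussian with phi = (mu, sigma).
mu, sigma = 0.3, 1.0
p_gate_nonzero = 1.0 - norm.cdf(0.0, loc=mu, scale=sigma)   # 1 - Q(s <= 0 | phi)

s_samples = np.random.default_rng(0).normal(mu, sigma, size=5)
# Negative samples clamp to exactly 0, samples above 1 clamp to exactly 1.
print(hard_sigmoid(s_samples))
print(p_gate_nonzero)
```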

Then, the objective is smoothed as:

\begin{align} \mathcal{R}(\tilde{\!\theta}, \!\phi) & = \E_{q(\*s|\!\phi)}\bigg[\frac{1}{N}\bigg(\sum_{i=1}^{N}\mathcal{L}\big(h(\*x_i; \tilde{\!\theta}\odot g(\*s)), \*y_i\big)\bigg)\bigg] + \lambda \sum_{j=1}^{|\theta|}\big(1 - Q(s_j \leq 0|\phi_j)\big),\label{eq:erm_hc}\\ & \qquad \tilde{\!\theta}^*, \!\phi^* = \argmin_{\tilde{\!\theta}, \!\phi}\{\mathcal{R}(\tilde{\!\theta}, \!\phi)\}\nonumber, \quad g(\cdot) = \min(1, \max(0, \cdot)). \end{align}

To implement q(\*s), we can express \*s as a deterministic transform f(\cdot) of a parameter-free noise distribution p(\!\epsilon) and the parameters \!\phi (the reparameterization trick):

\begin{align} \mathcal{R}(\tilde{\!\theta}, \!\phi) & = \E_{p(\!\epsilon)}\bigg[\frac{1}{N}\bigg(\sum_{i=1}^{N}\mathcal{L}\big(h(\*x_i; \tilde{\!\theta}\odot g(f(\!\phi, \!\epsilon))), \*y_i\big)\bigg)\bigg] + \lambda \sum_{j=1}^{|\theta|}\big(1 - Q(s_j \leq 0|\phi_j)\big), \end{align}

Although the expectation over p(\!\epsilon) is generally intractable, it can be approximated with a Monte Carlo estimate:

\begin{align} \hat{\mathcal{R}}(\tilde{\!\theta}, \!\phi) & = \frac{1}{L}\sum_{l=1}^{L}\bigg(\frac{1}{N}\bigg(\sum_{i=1}^{N}\mathcal{L}\big(h(\*x_i; \tilde{\!\theta}\odot\*z^{(l)}), \*y_i\big)\bigg)\bigg) + \lambda \sum_{j=1}^{|\theta|}\big(1 - Q(s_j \leq 0|\phi_j)\big)\nonumber\\ & = \mathcal{L}_E(\tilde{\!\theta}, \!\phi) + \lambda \mathcal{L}_C(\!\phi), \quad \text{where} \; \*z^{(l)} = g(f(\!\phi, \!\epsilon^{(l)})) \; \text{and} \; \!\epsilon^{(l)} \sim p(\!\epsilon).\label{eq:erm_mc_hc} \end{align}
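Putting the pieces together, the sketch below shows the structure of the two-term estimator \mathcal{L}_E + \lambda\mathcal{L}_C. The reparameterization f(\!\phi, \!\epsilon) = \!\phi + \!\epsilon with Gaussian noise, the linear model, and the squared-error loss are deliberately simple placeholders of my own, not the paper's choices; the paper's f is the hard concrete transform given next.

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)

def g(s):
    # Hard-sigmoid rectification g(s) = min(1, max(0, s)).
    return np.minimum(1.0, np.maximum(0.0, s))

def f(phi, eps):
    # Placeholder reparameterization (NOT the paper's hard concrete transform).
    return phi + eps

def mc_objective(theta_tilde, phi, X, y, lam, L=3):
    """Monte Carlo estimate L_E + lam * L_C of the smoothed objective."""
    data_term = 0.0
    for _ in range(L):
        eps = rng.normal(size=phi.shape)          # eps^(l) ~ p(eps), parameter-free noise
        z = g(f(phi, eps))                        # z^(l) = g(f(phi, eps^(l)))
        preds = X @ (theta_tilde * z)             # toy linear model h(x; theta_tilde ⊙ z)
        data_term += np.mean((preds - y) ** 2)    # toy squared-error loss
    L_E = data_term / L
    # With s = phi + eps and eps ~ N(0, 1), Q(s_j <= 0 | phi_j) = Phi(-phi_j).
    L_C = np.sum(1.0 - norm.cdf(-phi))
    return L_E + lam * L_C

X, y = rng.normal(size=(8, 4)), rng.normal(size=8)
print(mc_objective(theta_tilde=rng.normal(size=4), phi=np.zeros(4), X=X, y=y, lam=0.01))
```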

In other words, we draw L samples from p(\!\epsilon) and use the average over them as the approximation.

For q(\*s), the authors propose the hard concrete distribution: a binary concrete distribution with location \log\alpha and temperature \beta, stretched to an interval (\gamma, \zeta) with \gamma < 0 and \zeta > 1, and then rectified with the hard sigmoid:

\begin{align} u \sim \mathcal{U}(0, 1), \quad s = \text{Sigmoid}&\big((\log u - \log (1 - u) + \log \alpha) / \beta\big), \quad \bar{s} = s(\zeta - \gamma) + \gamma,\\ z & = \min(1, \max(0, \bar{s})). \end{align}
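The sampling procedure translates almost line for line into code. A minimal NumPy sketch, with the stretch interval (\gamma, \zeta) = (-0.1, 1.1) and temperature \beta = 2/3 used here as reasonable defaults:

```python
import numpy as np

def sample_hard_concrete(log_alpha, beta=2/3, gamma=-0.1, zeta=1.1, rng=None):
    """Draw gates z from the hard concrete distribution with location log_alpha."""
    if rng is None:
        rng = np.random.default_rng()
    # u ~ U(0, 1), kept away from the endpoints for numerical safety.
    u = rng.uniform(low=1e-6, high=1 - 1e-6, size=np.shape(log_alpha))
    s = 1.0 / (1.0 + np.exp(-(np.log(u) - np.log(1 - u) + log_alpha) / beta))
    s_bar = s * (zeta - gamma) + gamma              # stretch to (gamma, zeta)
    return np.minimum(1.0, np.maximum(0.0, s_bar))  # rectify: exact 0s and 1s are possible

log_alpha = np.array([-3.0, 0.0, 3.0])              # one location parameter per gate
print(sample_hard_concrete(log_alpha, rng=np.random.default_rng(0)))
```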

At test time there is no sampling; the gates are set with the following deterministic estimator:

\begin{align} \hat{\*z} = \min(\*1, \max(\*0, \text{Sigmoid}(\log\!\alpha)(\zeta - \gamma) + \gamma)), \qquad \!\theta^* = \tilde{\!\theta}^* \odot \hat{\*z}\label{eq:fn_theta}. \end{align}
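In code, the test-time gate is just the stretched-and-rectified sigmoid of \log\alpha, with no noise; gates whose location is sufficiently negative clamp to exactly zero, which is what produces the sparsity. A minimal sketch with toy values:

```python
import numpy as np

def test_time_gates(log_alpha, gamma=-0.1, zeta=1.1):
    """Deterministic gate estimate z_hat used at test time (no sampling)."""
    s = 1.0 / (1.0 + np.exp(-log_alpha))            # Sigmoid(log_alpha)
    return np.minimum(1.0, np.maximum(0.0, s * (zeta - gamma) + gamma))

log_alpha = np.array([-3.0, 0.0, 3.0])
theta_tilde = np.array([0.5, -1.0, 2.0])
z_hat = test_time_gates(log_alpha)
theta_star = theta_tilde * z_hat                    # theta* = theta_tilde ⊙ z_hat
print(z_hat)          # first gate clamps to exactly 0, so its weight is pruned
print(theta_star)
```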

It’s also possible to combine the L_0 penalty with an L_2 penalty. Under Bernoulli gating, it is straightforward to see that the expected L_2 norm can be expressed as:

\begin{align} \E_{q(\*z|\!\pi)}\big[\|\!\theta\|_2^2\big] = \sum_{j=1}^{|\theta|}\E_{q(z_j|\pi_j)}\big[z_j^2\tilde{\theta}_j^2\big] = \sum_{j=1}^{|\theta|}\pi_j\tilde{\theta}_j^2, \end{align}

The same hard concrete smoothing applies here. Note that the L_2 penalty corresponds to the negative log density of a zero-mean Gaussian prior on the weights; to keep the gate from shrinking \tilde{\theta}_j twice, the prior’s standard deviation is set to \sigma = 1 when z = 0 and \sigma = z otherwise. The rescaled weight \hat{\theta} = \frac{\theta}{\sigma} then has unit standard deviation, and its expected L_2 norm is as follows.

\begin{align} \E_{q(\*z|\!\phi)}\big[\|\hat{\!\theta}\|^2_2\big] & = \sum_{j=1}^{|\theta|}\bigg( Q_{\bar{s}_j}(0|\phi_j)\frac{0}{1} + \big(1 - Q_{\bar{s}_j}(0|\phi_j)\big)\E_{q(z_j|\phi_j, \bar{s}_j > 0)}\bigg[\frac{\tilde{\theta}_j^2 \cancel{z_j^2}}{\cancel{z_j^2}}\bigg]\bigg) \nonumber \\ & = \sum_{j=1}^{|\theta|}\big(1 - Q_{\bar{s}_j}(0|\phi_j)\big)\tilde{\theta}_j^2. \end{align}
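Under the hard concrete gate, the CDF term has a closed form, 1 - Q_{\bar{s}_j}(0|\phi_j) = \text{Sigmoid}(\log\alpha_j - \beta\log(-\gamma/\zeta)), which follows from the CDF of the stretched binary concrete. A small sketch computing both expected penalties with it; the toy values are mine:

```python
import numpy as np

def prob_gate_active(log_alpha, beta=2/3, gamma=-0.1, zeta=1.1):
    """Closed form of 1 - Q(s_bar <= 0 | phi) for a hard concrete gate."""
    return 1.0 / (1.0 + np.exp(-(log_alpha - beta * np.log(-gamma / zeta))))

log_alpha = np.array([-3.0, 0.0, 3.0])
theta_tilde = np.array([0.5, -1.0, 2.0])

p_active = prob_gate_active(log_alpha)
expected_l0 = np.sum(p_active)                      # E[||theta||_0]
expected_l2 = np.sum(p_active * theta_tilde ** 2)   # E[||theta_hat||_2^2]
print(expected_l0, expected_l2)
```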

We can also let a group of parameters share the same gate, which gives similar penalties summed over groups, with each group’s L_0 term weighted by its size |g|.

\begin{align} \E_{q(\*z|\!\phi)}\bigg[\|\!\theta\|_0\bigg] & = \sum_{g=1}^{|G|}|g|\bigg(1 - Q(s_g \leq 0|\phi_g)\bigg) \\ \E_{q(\*z|\!\phi)}\bigg[\|\hat{\!\theta}\|^2_2\bigg] &= \sum_{g=1}^{|G|}\bigg(\big(1 - Q(s_g \leq 0|\phi_g)\big)\sum_{j=1}^{|g|}\tilde{\theta}^2_j\bigg). \end{align}
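A sketch of the group version, where one gate covers a whole weight group (say, all incoming weights of a neuron); the grouping shown is hypothetical:

```python
import numpy as np

def group_penalties(log_alpha_g, theta_groups, beta=2/3, gamma=-0.1, zeta=1.1):
    """Expected L0 / L2 penalties when each group of weights shares a single gate."""
    p_active = 1.0 / (1.0 + np.exp(-(log_alpha_g - beta * np.log(-gamma / zeta))))
    l0 = sum(p * len(theta) for p, theta in zip(p_active, theta_groups))         # |g| * P(gate active)
    l2 = sum(p * np.sum(theta ** 2) for p, theta in zip(p_active, theta_groups))
    return l0, l2

# Hypothetical grouping: one gate per "neuron", covering that neuron's weight vector.
theta_groups = [np.array([0.5, -1.0]), np.array([2.0, 0.1, 0.3])]
log_alpha_g = np.array([0.0, 3.0])
print(group_penalties(log_alpha_g, theta_groups))
```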

Overall Recommendation

  • 5: Transformative: This paper is likely to change our field. It should be considered for a best paper award.
  • 4.5: Exciting: It changed my thinking on this topic. I would fight for it to be accepted.
  • 4: Strong: I learned a lot from it. I would like to see it accepted.
  • 3.5: Leaning positive: It can be accepted more or less in its current form. However, the work it describes is not particularly exciting and/or inspiring, so it will not be a big loss if people don’t see it in this conference.
  • 3: Ambivalent: It has merits (e.g., it reports state-of-the-art results, the idea is nice), but there are key weaknesses (e.g., I didn’t learn much from it, evaluation is not convincing, it describes incremental work). I believe it can significantly benefit from another round of revision, but I won’t object to accepting it if my co-reviewers are willing to champion it.
  • 2.5: Leaning negative: I am leaning towards rejection, but I can be persuaded if my co-reviewers think otherwise.
  • 2: Mediocre: I would rather not see it in the conference.
  • 1.5: Weak: I am pretty confident that it should be rejected.
  • 1: Poor: I would fight to have it rejected.