Skip to content

Latest commit

 

History

History
489 lines (364 loc) · 28.7 KB

File metadata and controls

489 lines (364 loc) · 28.7 KB

Lan truyền Ngược qua Thời gian

🏷️sec_bptt

Cho đến nay chúng ta liên tục nhắc đến những vấn đề như bùng nổ gradient, tiêu biến gradient, cắt xén lan truyển ngược và sự cần thiết của việc tách đồ thị tính toán. Ví dụ, trong phần trước chúng ta gọi hàm s.detach() trên chuỗi. Vì muốn nhanh chóng xây dựng và quan sát cách một mô hình hoạt động nên những vấn đề này chưa được giải thích một cách đầy đủ. Trong phần này chúng ta sẽ nghiên cứu sâu và chi tiết hơn về lan truyền ngược cho các mô hình chuỗi và giải thích nguyên lý toán học đằng sau. Để hiểu chi tiết hơn về tính ngẫu nhiên và lan truyền ngược, hãy tham khảo bài báo :cite:Tallec.Ollivier.2017.

Chúng ta đã thấy một vài hậu quả của bùng nổ gradient khi lập trình mạng nơ-ron hồi tiếp (:numref:sec_rnn_scratch). Cụ thể, nếu bạn đã làm xong bài tập ở phần đó, bạn sẽ thấy rằng việc gọt gradient đóng vai trò rất quan trọng để đảm bảo mô hình hội tụ. Để có cái nhìn rõ hơn về vấn đề này, trong phần này chúng ta sẽ xem xét cách tính gradient cho các mô hình chuỗi. Lưu ý rằng, về mặt khái niệm thì không có gì mới ở đây. Sau cùng, chúng ta vẫn chỉ đơn thuần áp dụng các quy tắc dây chuyền để tính gradient. Tuy nhiên, việc ôn lại lan truyền ngược (:numref:sec_backprop) vẫn rất hữu ích.

Lượt truyền xuôi trong mạng nơ-ron hồi tiếp tương đối đơn giản. Lan truyền ngược qua thời gian thực chất là một ứng dụng cụ thể của lan truyền ngược trong mạng nơ-ron hồi tiếp. Nó đòi hỏi chúng ta mở rộng mạng nơ-ron hồi tiếp theo từng bước thời gian một để thu được sự phụ thuộc giữa các biến mô hình và các tham số. Sau đó, dựa trên quy tắc dây chuyền, chúng ta áp dụng lan truyền ngược để tính toán và lưu các giá trị gradient. Vì chuỗi có thể khá dài nên sự phụ thuộc trong chuỗi cũng có thể rất dài. Ví dụ, đối với một chuỗi gồm 1000 ký tự, ký tự đầu tiên có thể ảnh hưởng đáng kể tới ký tự ở vị trí 1000. Điều này không thực sự khả thi về mặt tính toán (cần quá nhiều thời gian và bộ nhớ) và nó đòi hỏi hơn 1000 phép nhân ma trận-vector trước khi thu được các giá trị gradient khó nắm bắt này. Đây là một quá trình chứa đầy sự bất định về mặt tính toán và thống kê. Trong phần tiếp theo chúng ta sẽ làm sáng tỏ những gì sẽ xảy ra và cách giải quyết vấn đề này trong thực tế.

Mạng Hồi tiếp Giản thể

Hãy bắt đầu với một mô hình đơn giản về cách mà mạng RNN hoạt động. Mô hình này bỏ qua các chi tiết cụ thể của trạng thái ẩn và cách trạng thái này được cập nhật. Những chi tiết này không quan trọng đối với việc phân tích dưới đây mà chỉ khiến các ký hiệu trở nên lộn xộn và phức tạp quá mức. Trong mô hình đơn giản này, chúng ta ký hiệu $h_t$ là trạng thái ẩn, $x_t$ là đầu vào, và $o_t$ là đầu ra tại bước thời gian $t$. Bên cạnh đó, $w_h$$w_o$ tương ứng với trọng số của các trạng thái ẩn và tầng đầu ra. Kết quả là, các trạng thái ẩn và kết quả đầu ra tại mỗi bước thời gian có thể được giải thích như sau

$$h_t = f(x_t, h_{t-1}, w_h) \text{ và } o_t = g(h_t, w_o).$$

Do đó, chúng ta có một chuỗi các giá trị ${\ldots, (h_{t-1}, x_{t-1}, o_{t-1}), (h_{t}, x_{t}, o_t), \ldots}$ phụ thuộc vào nhau thông qua phép tính đệ quy. Lượt truyền xuôi khá đơn giản. Những gì chúng ta cần là lặp qua từng bộ ba $(x_t, h_t, o_t)$ một. Sau đó, sự khác biệt giữa kết quả đầu ra $o_t$ và các giá trị mục tiêu mong muốn $y_t$ được tính bằng một hàm mục tiêu

$$L(x, y, w_h, w_o) = \sum_{t=1}^T l(y_t, o_t).$$

Đối với lan truyền ngược, mọi thứ lại phức tạp hơn một chút, đặc biệt là khi chúng ta tính gradient theo các tham số $w_h$ của hàm mục tiêu $L$. Cụ thể, theo quy tắc dây chuyền ta có

$$\begin{aligned} \partial_{w_h} L & = \sum_{t=1}^T \partial_{w_h} l(y_t, o_t) \\ & = \sum_{t=1}^T \partial_{o_t} l(y_t, o_t) \partial_{h_t} g(h_t, w_h) \left[ \partial_{w_h} h_t\right]. \end{aligned}$$

Ta có thể tính phần đầu tiên và phần thứ hai của đạo hàm một cách dễ dàng. Phần thứ ba $\partial_{w_h} h_t$ khiến mọi thứ trở nên khó khăn, vì chúng ta cần phải tính toán ảnh hưởng của các tham số tới $h_t$.

Để tính được gradient ở trên, giả sử rằng chúng ta có ba chuỗi ${a_{t}},{b_{t}},{c_{t}}$ thỏa mãn $a_{0}=0, a_{1}=b_{1}$$a_{t}=b_{t}+c_{t}a_{t-1}$ với $t=1, 2,\ldots$. Sau đó, với $t\geq 1$ ta có

$$a_{t}=b_{t}+\sum_{i=1}^{t-1}\left(\prod_{j=i+1}^{t}c_{j}\right)b_{i}.$$ :eqlabel:eq_bptt_at

Bây giờ chúng ta áp dụng :eqref:eq_bptt_at với

$$a_t = \partial_{w_h}h_{t},$$

$$b_t = \partial_{w_h}f(x_{t},h_{t-1},w_h), $$

$$c_t = \partial_{h_{t-1}}f(x_{t},h_{t-1},w_h).$$

Vì vậy, công thức $a_{t}=b_{t}+c_{t}a_{t-1}$ trở thành phép đệ quy dưới đây

$$ \partial_{w_h}h_{t}=\partial_{w_h}f(x_{t},h_{t-1},w)+\partial_{h}f(x_{t},h_{t-1},w_h)\partial_{w_h}h_{t-1}. $$

Sử dụng :eqref:eq_bptt_at, phần thứ ba sẽ trở thành

$$ \partial_{w_h}h_{t}=\partial_{w_h}f(x_{t},h_{t-1},w_h)+\sum_{i=1}^{t-1}\left(\prod_{j=i+1}^{t}\partial_{h_{j-1}}f(x_{j},h_{j-1},w_h)\right)\partial_{w_h}f(x_{i},h_{i-1},w_h). $$

Dù chúng ta có thể sử dụng quy tắc dây chuyền để tính $\partial_w h_t$ một cách đệ quy, dây chuyền này có thể trở nên rất dài khi giá trị $t$ lớn. Hãy cùng thảo luận về một số chiến lược để giải quyết vấn đề này.

  • Tính toàn bộ tổng. Cách này rất chậm và gradient có thể bùng nổ vì những thay đổi nhỏ trong các điều kiện ban đầu cũng có khả năng ảnh hưởng đến kết quả rất nhiều. Điều này tương tự như trong hiệu ứng cánh bướm, khi những thay đổi rất nhỏ trong điều kiện ban đầu dẫn đến những thay đổi không cân xứng trong kết quả. Đây thực sự là điều không mong muốn khi xét tới mô hình mà chúng ta muốn ước lượng. Sau cùng, chúng ta đang cố tìm kiếm một bộ ước lượng mạnh mẽ và có khả năng khái quát tốt. Do đó chiến lược này hầu như không bao giờ được sử dụng trong thực tế.
  • Cắt xén tổng sau $\tau$ bước. Cho đến giây phút hiện tại, đây là những gì chúng ta đã thảo luận. Điều này dẫn tới một phép xấp xỉ của gradient, đơn giản bằng cách kết thúc tổng trên tại $\partial_w h_{t-\tau}$. Do đó lỗi xấp xỉ là $\partial_h f(x_t, h_{t-1}, w) \partial_w h_{t-1}$ (nhân với tích của gradient liên quan đến $\partial_h f$). Trong thực tế, chiến lược này hoạt động khá tốt. Phương pháp này thường được gọi là BPTT (backpropagation through time --- lan truyền ngược qua thời gian) bị cắt xén. Một trong những hệ quả của phương pháp này là mô hình sẽ tập trung chủ yếu vào ảnh hưởng ngắn hạn thay vì dài hạn. Đây thực sự là điều mà chúng ta mong muốn, vì nó hướng sự ước lượng tới các mô hình đơn giản và ổn định hơn.
  • Cắt xén Ngẫu nhiên. Cuối cùng, chúng ta có thể thay thế $\partial_{w_h} h_t$ bằng một biến ngẫu nhiên có giá trị kỳ vọng đúng nhưng vẫn cắt xén chuỗi.
  • Điều này có thể đạt được bằng cách sử dụng một chuỗi các $\xi_t$ trong đó $E[\xi_t] = 1$, $P(\xi_t = 0) = 1-\pi$$P(\xi_t = \pi^{-1}) = \pi$.
  • Chúng ta sẽ sử dụng chúng thay vì gradient:

$$z_t = \partial_w f(x_t, h_{t-1}, w) + \xi_t \partial_h f(x_t, h_{t-1}, w) \partial_w h_{t-1}.$$

Từ định nghĩa của $\xi_t$, ta có $E[z_t] = \partial_w h_t$. Bất cứ khi nào $\xi_t = 0$, khai triển sẽ kết thúc tại điểm đó. Điều này dẫn đến một tổng trọng số của các chuỗi có chiều dài biến thiên, trong đó chuỗi dài sẽ hiếm hơn nhưng được đánh trọng số cao hơn tương ứng. :cite:Tallec.Ollivier.2017 đưa ra đề xuất này trong bài báo nghiên cứu của họ. Không may, dù phương pháp này khá hấp dẫn về mặt lý thuyết, nó lại không tốt hơn phương pháp cắt xén đơn giản, nhiều khả năng do các yếu tố sau. Thứ nhất, tác động của một quan sát đến quá khứ sau một vài lượt lan truyền ngược đã là tương đối đủ để nắm bắt các phụ thuộc trên thực tế. Thứ hai, phương sai tăng lên làm phản tác dụng của việc có gradient chính xác hơn. Thứ ba, ta thực sự muốn các mô hình có khoảng tương tác ngắn. Do đó, BPTT có một hiệu ứng điều chuẩn nhỏ mà có thể có ích.

Từ trên xuống dưới: BPTT ngẫu nhiên, BPTT bị cắt xén đều và BPTT đầy đủ 🏷️fig_truncated_bptt

:numref:fig_truncated_bptt minh họa ba trường hợp trên khi phân tích một số từ đầu tiên trong Cỗ máy Thời gian:

  • Dòng đầu tiên biểu diễn sự cắt xén ngẫu nhiên, chia văn bản thành các phần có độ dài biến thiên.
  • Dòng thứ hai biểu diễn BPTT bị cắt xén đều, chia văn bản thành các phần có độ dài bằng nhau.
  • Dòng thứ ba là BPTT đầy đủ, dẫn đến một biểu thức không khả thi về mặt tính toán.

Đồ thị Tính toán

Để minh họa trực quan sự phụ thuộc giữa các biến và tham số mô hình trong suốt quá trình tính toán của mạng nơ-ron hồi tiếp, ta có thể vẽ đồ thị tính toán của mô hình, như trong :numref:fig_rnn_bptt. Ví dụ, việc tính toán trạng thái ẩn ở bước thời gian 3, $\mathbf{h}3$, phụ thuộc vào các tham số $\mathbf{W}{hx}$ và $\mathbf{W}_{hh}$ của mô hình, trạng thái ẩn ở bước thời gian trước đó $\mathbf{h}_2$, và đầu vào ở bước thời gian hiện tại $\mathbf{x}_3$.

Sự phụ thuộc về mặt tính toán của mạng nơ-ron hồi tiếp với ba bước thời gian. Ô vuông tượng trưng cho các biến (không tô đậm) hoặc các tham số (tô đậm), hình tròn tượng trưng cho các phép toán. 🏷️fig_rnn_bptt

BPTT chi tiết

Sau khi thảo luận các nguyên lý chung, hãy phân tích BPTT một cách chi tiết. Bằng cách tách $\mathbf{W}$ thành các tập ma trận trọng số khác nhau $\mathbf{W}{hx}, \mathbf{W}{hh}$ và $\mathbf{W}_{oh}$), ta thu được mô hình biến tiềm ẩn tuyến tính đơn giản:

$$\mathbf{h}t = \mathbf{W}{hx} \mathbf{x}t + \mathbf{W}{hh} \mathbf{h}_{t-1} \text{ và } \mathbf{o}t = \mathbf{W}{oh} \mathbf{h}_t.$$

Theo thảo luận ở :numref:sec_backprop, ta tính các gradient $\frac{\partial L}{\partial \mathbf{W}{hx}}$, $\frac{\partial L}{\partial \mathbf{W}{hh}}$, $\frac{\partial L}{\partial \mathbf{W}_{oh}}$ cho

$$L(\mathbf{x}, \mathbf{y}, \mathbf{W}) = \sum_{t=1}^T l(\mathbf{o}_t, y_t),$$

với $l(\cdot)$ là hàm mất mát đã chọn trước. Tính đạo hàm theo $W_{oh}$ khá đơn giản, ta có

$$\partial_{\mathbf{W}{oh}} L = \sum{t=1}^T \mathrm{prod} \left(\partial_{\mathbf{o}_t} l(\mathbf{o}_t, y_t), \mathbf{h}_t\right),$$

với $\mathrm{prod} (\cdot)$ là tích của hai hoặc nhiều ma trận.

Sự phụ thuộc vào $\mathbf{W}{hx}$ và $\mathbf{W}{hh}$ thì khó khăn hơn một chút vì cần sử dụng quy tắc dây chuyền khi tính toán đạo hàm. Ta sẽ bắt đầu với

$$\begin{aligned} \partial_{\mathbf{W}{hh}} L & = \sum{t=1}^T \mathrm{prod} \left(\partial_{\mathbf{o}t} l(\mathbf{o}t, y_t), \mathbf{W}{oh}, \partial{\mathbf{W}{hh}} \mathbf{h}t\right), \ \partial{\mathbf{W}{hx}} L & = \sum_{t=1}^T \mathrm{prod} \left(\partial_{\mathbf{o}t} l(\mathbf{o}t, y_t), \mathbf{W}{oh}, \partial{\mathbf{W}_{hx}} \mathbf{h}_t\right). \end{aligned}$$

Sau cùng, các trạng thái ẩn phụ thuộc lẫn nhau và phụ thuộc vào đầu vào quá khứ. Một đại lượng quan trọng là sư ảnh hưởng của các trạng thái ẩn quá khứ tới các trạng thái ẩn tương lai.

$$\partial_{\mathbf{h}t} \mathbf{h}{t+1} = \mathbf{W}{hh}^\top \text{ do~đó } \partial{\mathbf{h}_t} \mathbf{h}T = \left(\mathbf{W}{hh}^\top\right)^{T-t}.$$

Áp dụng quy tắc dây chuyền ta được

$$\begin{aligned} \partial_{\mathbf{W}{hh}} \mathbf{h}t & = \sum{j=1}^t \left(\mathbf{W}{hh}^\top\right)^{t-j} \mathbf{h}j \ \partial{\mathbf{W}{hx}} \mathbf{h}t & = \sum{j=1}^t \left(\mathbf{W}{hh}^\top\right)^{t-j} \mathbf{x}_j. \end{aligned}$$

Ta có thể rút ra nhiều điều từ biểu thức phức tạp này. Đầu tiên, việc lưu lại các kết quả trung gian, tức các luỹ thừa của $\mathbf{W}{hh}$ khi tính các số hạng của hàm mất mát $L$, là rất hữu ích. Thứ hai, ví dụ tuyến tính này dù đơn giản nhưng đã làm lộ ra một vấn đề chủ chốt của các mô hình chuỗi dài: ta có thể phải làm việc với các luỹ thừa rất lớn của $\mathbf{W}{hh}^j$. Trong đó, khi $j$ lớn, các trị riêng nhỏ hơn $1$ sẽ tiêu biến, còn các trị riêng lớn hơn $1$ sẽ phân kì. Các mô hình này không có tính ổn định số học, dẫn đến việc chúng quan trọng hóa quá mức các chi tiết không liên quan trong quá khứ. Một cách giải quyết vấn đề này là cắt xén các số hạng trong tổng ở một mức độ thuận tiện cho việc tính toán. Sau này ở :numref:chap_modern_rnn, ta sẽ thấy cách các mô hình chuỗi phức tạp như LSTM giải quyết vấn đề này tốt hơn. Khi lập trình, ta cắt xén các số hạng bằng cách tách rời gradient sau một số lượng bước nhất định.

Tóm tắt

  • Lan truyền ngược theo thời gian chỉ là việc áp dụng lan truyền ngược cho các mô hình chuỗi có trạng thái ẩn.
  • Việc cắt xén là cần thiết để thuận tiện cho việc tính toán và ổn định các giá trị số.
  • Luỹ thừa lớn của ma trận có thể làm các trị riêng tiêu biến hoặc phân kì, biểu hiện dưới hiện tượng tiêu biến hoặc bùng nổ gradient.
  • Để tăng hiệu năng tính toán, các giá trị trung gian được lưu lại.

Bài tập

  1. Cho ma trận đối xứng $\mathbf{M} \in \mathbb{R}^{n \times n}$ với các trị riêng $\lambda_i$. Không làm mất tính tổng quát, ta giả sử chúng được sắp xếp theo thứ tự tăng dần $\lambda_i \leq \lambda_{i+1}$. Chứng minh rằng $\mathbf{M}^k$ có các trị riêng là $\lambda_i^k$.
  2. Chứng minh rằng với vector bất kì $\mathbf{x} \in \mathbb{R}^n$, xác suất cao là $\mathbf{M}^k \mathbf{x}$ sẽ xấp xỉ vector trị riêng lớn nhất $\mathbf{v}_n$ của $\mathbf{M}$.
  3. Kết quả trên có ý nghĩa như thế nào khi tính gradient của mạng nơ-ron hồi tiếp?
  4. Ngoài gọt gradient, có phương pháp nào để xử lý bùng nổ gradient trong mạng nơ-ron hồi tiếp không?

Thảo luận

Những người thực hiện

Bản dịch trong trang này được thực hiện bởi:

  • Đoàn Võ Duy Thanh
  • Nguyễn Văn Quang
  • Lê Khắc Hồng Phúc
  • Nguyễn Văn Cường
  • Phạm Minh Đức
  • Phạm Hồng Vinh