🏷️sec_bptt
Cho đến nay chúng ta liên tục nhắc đến những vấn đề như bùng nổ gradient, tiêu biến gradient, cắt xén lan truyển ngược và sự cần thiết của việc tách đồ thị tính toán.
Ví dụ, trong phần trước chúng ta gọi hàm s.detach()
trên chuỗi.
Vì muốn nhanh chóng xây dựng và quan sát cách một mô hình hoạt động nên những vấn đề này chưa được giải thích một cách đầy đủ.
Trong phần này chúng ta sẽ nghiên cứu sâu và chi tiết hơn về lan truyền ngược cho các mô hình chuỗi và giải thích nguyên lý toán học đằng sau.
Để hiểu chi tiết hơn về tính ngẫu nhiên và lan truyền ngược, hãy tham khảo bài báo :cite:Tallec.Ollivier.2017
.
Chúng ta đã thấy một vài hậu quả của bùng nổ gradient khi lập trình mạng nơ-ron hồi tiếp (:numref:sec_rnn_scratch
).
Cụ thể, nếu bạn đã làm xong bài tập ở phần đó, bạn sẽ thấy rằng việc gọt gradient đóng vai trò rất quan trọng để đảm bảo mô hình hội tụ.
Để có cái nhìn rõ hơn về vấn đề này, trong phần này chúng ta sẽ xem xét cách tính gradient cho các mô hình chuỗi.
Lưu ý rằng, về mặt khái niệm thì không có gì mới ở đây.
Sau cùng, chúng ta vẫn chỉ đơn thuần áp dụng các quy tắc dây chuyền để tính gradient.
Tuy nhiên, việc ôn lại lan truyền ngược (:numref:sec_backprop
) vẫn rất hữu ích.
Lượt truyền xuôi trong mạng nơ-ron hồi tiếp tương đối đơn giản. Lan truyền ngược qua thời gian thực chất là một ứng dụng cụ thể của lan truyền ngược trong mạng nơ-ron hồi tiếp. Nó đòi hỏi chúng ta mở rộng mạng nơ-ron hồi tiếp theo từng bước thời gian một để thu được sự phụ thuộc giữa các biến mô hình và các tham số. Sau đó, dựa trên quy tắc dây chuyền, chúng ta áp dụng lan truyền ngược để tính toán và lưu các giá trị gradient. Vì chuỗi có thể khá dài nên sự phụ thuộc trong chuỗi cũng có thể rất dài. Ví dụ, đối với một chuỗi gồm 1000 ký tự, ký tự đầu tiên có thể ảnh hưởng đáng kể tới ký tự ở vị trí 1000. Điều này không thực sự khả thi về mặt tính toán (cần quá nhiều thời gian và bộ nhớ) và nó đòi hỏi hơn 1000 phép nhân ma trận-vector trước khi thu được các giá trị gradient khó nắm bắt này. Đây là một quá trình chứa đầy sự bất định về mặt tính toán và thống kê. Trong phần tiếp theo chúng ta sẽ làm sáng tỏ những gì sẽ xảy ra và cách giải quyết vấn đề này trong thực tế.
Hãy bắt đầu với một mô hình đơn giản về cách mà mạng RNN hoạt động.
Mô hình này bỏ qua các chi tiết cụ thể của trạng thái ẩn và cách trạng thái này được cập nhật.
Những chi tiết này không quan trọng đối với việc phân tích dưới đây mà chỉ khiến các ký hiệu trở nên lộn xộn và phức tạp quá mức.
Trong mô hình đơn giản này, chúng ta ký hiệu
Do đó, chúng ta có một chuỗi các giá trị
Đối với lan truyền ngược, mọi thứ lại phức tạp hơn một chút, đặc biệt là khi chúng ta tính gradient theo các tham số
Ta có thể tính phần đầu tiên và phần thứ hai của đạo hàm một cách dễ dàng.
Phần thứ ba
Để tính được gradient ở trên, giả sử rằng chúng ta có ba chuỗi
eq_bptt_at
Bây giờ chúng ta áp dụng :eqref:eq_bptt_at
với
Vì vậy, công thức
Sử dụng :eqref:eq_bptt_at
, phần thứ ba sẽ trở thành
Dù chúng ta có thể sử dụng quy tắc dây chuyền để tính
- Tính toàn bộ tổng. Cách này rất chậm và gradient có thể bùng nổ vì những thay đổi nhỏ trong các điều kiện ban đầu cũng có khả năng ảnh hưởng đến kết quả rất nhiều. Điều này tương tự như trong hiệu ứng cánh bướm, khi những thay đổi rất nhỏ trong điều kiện ban đầu dẫn đến những thay đổi không cân xứng trong kết quả. Đây thực sự là điều không mong muốn khi xét tới mô hình mà chúng ta muốn ước lượng. Sau cùng, chúng ta đang cố tìm kiếm một bộ ước lượng mạnh mẽ và có khả năng khái quát tốt. Do đó chiến lược này hầu như không bao giờ được sử dụng trong thực tế.
-
Cắt xén tổng sau
$\tau$ bước. Cho đến giây phút hiện tại, đây là những gì chúng ta đã thảo luận. Điều này dẫn tới một phép xấp xỉ của gradient, đơn giản bằng cách kết thúc tổng trên tại$\partial_w h_{t-\tau}$ . Do đó lỗi xấp xỉ là$\partial_h f(x_t, h_{t-1}, w) \partial_w h_{t-1}$ (nhân với tích của gradient liên quan đến$\partial_h f$ ). Trong thực tế, chiến lược này hoạt động khá tốt. Phương pháp này thường được gọi là BPTT (backpropagation through time --- lan truyền ngược qua thời gian) bị cắt xén. Một trong những hệ quả của phương pháp này là mô hình sẽ tập trung chủ yếu vào ảnh hưởng ngắn hạn thay vì dài hạn. Đây thực sự là điều mà chúng ta mong muốn, vì nó hướng sự ước lượng tới các mô hình đơn giản và ổn định hơn.
-
Cắt xén Ngẫu nhiên. Cuối cùng, chúng ta có thể thay thế
$\partial_{w_h} h_t$ bằng một biến ngẫu nhiên có giá trị kỳ vọng đúng nhưng vẫn cắt xén chuỗi. - Điều này có thể đạt được bằng cách sử dụng một chuỗi các
$\xi_t$ trong đó$E[\xi_t] = 1$ ,$P(\xi_t = 0) = 1-\pi$ và$P(\xi_t = \pi^{-1}) = \pi$ . - Chúng ta sẽ sử dụng chúng thay vì gradient:
Từ định nghĩa của Tallec.Ollivier.2017
đưa ra đề xuất này trong bài báo nghiên cứu của họ.
Không may, dù phương pháp này khá hấp dẫn về mặt lý thuyết, nó lại không tốt hơn phương pháp cắt xén đơn giản, nhiều khả năng do các yếu tố sau.
Thứ nhất, tác động của một quan sát đến quá khứ sau một vài lượt lan truyền ngược đã là tương đối đủ để nắm bắt các phụ thuộc trên thực tế.
Thứ hai, phương sai tăng lên làm phản tác dụng của việc có gradient chính xác hơn.
Thứ ba, ta thực sự muốn các mô hình có khoảng tương tác ngắn.
Do đó, BPTT có một hiệu ứng điều chuẩn nhỏ mà có thể có ích.
:numref:fig_truncated_bptt
minh họa ba trường hợp trên khi phân tích một số từ đầu tiên trong Cỗ máy Thời gian:
- Dòng đầu tiên biểu diễn sự cắt xén ngẫu nhiên, chia văn bản thành các phần có độ dài biến thiên.
- Dòng thứ hai biểu diễn BPTT bị cắt xén đều, chia văn bản thành các phần có độ dài bằng nhau.
- Dòng thứ ba là BPTT đầy đủ, dẫn đến một biểu thức không khả thi về mặt tính toán.
Để minh họa trực quan sự phụ thuộc giữa các biến và tham số mô hình trong suốt quá trình tính toán của mạng nơ-ron hồi tiếp, ta có thể vẽ đồ thị tính toán của mô hình, như trong :numref:fig_rnn_bptt
.
Ví dụ, việc tính toán trạng thái ẩn ở bước thời gian 3, $\mathbf{h}3$, phụ thuộc vào các tham số $\mathbf{W}{hx}$ và
Sau khi thảo luận các nguyên lý chung, hãy phân tích BPTT một cách chi tiết.
Bằng cách tách
$$\mathbf{h}t = \mathbf{W}{hx} \mathbf{x}t + \mathbf{W}{hh} \mathbf{h}_{t-1} \text{ và } \mathbf{o}t = \mathbf{W}{oh} \mathbf{h}_t.$$
Theo thảo luận ở :numref:sec_backprop
, ta tính các gradient $\frac{\partial L}{\partial \mathbf{W}{hx}}$, $\frac{\partial L}{\partial \mathbf{W}{hh}}$,
với
$$\partial_{\mathbf{W}{oh}} L = \sum{t=1}^T \mathrm{prod} \left(\partial_{\mathbf{o}_t} l(\mathbf{o}_t, y_t), \mathbf{h}_t\right),$$
với
Sự phụ thuộc vào $\mathbf{W}{hx}$ và $\mathbf{W}{hh}$ thì khó khăn hơn một chút vì cần sử dụng quy tắc dây chuyền khi tính toán đạo hàm. Ta sẽ bắt đầu với
$$\begin{aligned} \partial_{\mathbf{W}{hh}} L & = \sum{t=1}^T \mathrm{prod} \left(\partial_{\mathbf{o}t} l(\mathbf{o}t, y_t), \mathbf{W}{oh}, \partial{\mathbf{W}{hh}} \mathbf{h}t\right), \ \partial{\mathbf{W}{hx}} L & = \sum_{t=1}^T \mathrm{prod} \left(\partial_{\mathbf{o}t} l(\mathbf{o}t, y_t), \mathbf{W}{oh}, \partial{\mathbf{W}_{hx}} \mathbf{h}_t\right). \end{aligned}$$
Sau cùng, các trạng thái ẩn phụ thuộc lẫn nhau và phụ thuộc vào đầu vào quá khứ. Một đại lượng quan trọng là sư ảnh hưởng của các trạng thái ẩn quá khứ tới các trạng thái ẩn tương lai.
$$\partial_{\mathbf{h}t} \mathbf{h}{t+1} = \mathbf{W}{hh}^\top \text{ do~đó } \partial{\mathbf{h}_t} \mathbf{h}T = \left(\mathbf{W}{hh}^\top\right)^{T-t}.$$
Áp dụng quy tắc dây chuyền ta được
$$\begin{aligned} \partial_{\mathbf{W}{hh}} \mathbf{h}t & = \sum{j=1}^t \left(\mathbf{W}{hh}^\top\right)^{t-j} \mathbf{h}j \ \partial{\mathbf{W}{hx}} \mathbf{h}t & = \sum{j=1}^t \left(\mathbf{W}{hh}^\top\right)^{t-j} \mathbf{x}_j. \end{aligned}$$
Ta có thể rút ra nhiều điều từ biểu thức phức tạp này.
Đầu tiên, việc lưu lại các kết quả trung gian, tức các luỹ thừa của $\mathbf{W}{hh}$ khi tính các số hạng của hàm mất mát $L$, là rất hữu ích.
Thứ hai, ví dụ tuyến tính này dù đơn giản nhưng đã làm lộ ra một vấn đề chủ chốt của các mô hình chuỗi dài: ta có thể phải làm việc với các luỹ thừa rất lớn của $\mathbf{W}{hh}^j$.
Trong đó, khi chap_modern_rnn
, ta sẽ thấy cách các mô hình chuỗi phức tạp như LSTM giải quyết vấn đề này tốt hơn.
Khi lập trình, ta cắt xén các số hạng bằng cách tách rời gradient sau một số lượng bước nhất định.
- Lan truyền ngược theo thời gian chỉ là việc áp dụng lan truyền ngược cho các mô hình chuỗi có trạng thái ẩn.
- Việc cắt xén là cần thiết để thuận tiện cho việc tính toán và ổn định các giá trị số.
- Luỹ thừa lớn của ma trận có thể làm các trị riêng tiêu biến hoặc phân kì, biểu hiện dưới hiện tượng tiêu biến hoặc bùng nổ gradient.
- Để tăng hiệu năng tính toán, các giá trị trung gian được lưu lại.
- Cho ma trận đối xứng
$\mathbf{M} \in \mathbb{R}^{n \times n}$ với các trị riêng$\lambda_i$ . Không làm mất tính tổng quát, ta giả sử chúng được sắp xếp theo thứ tự tăng dần$\lambda_i \leq \lambda_{i+1}$ . Chứng minh rằng$\mathbf{M}^k$ có các trị riêng là$\lambda_i^k$ . - Chứng minh rằng với vector bất kì
$\mathbf{x} \in \mathbb{R}^n$ , xác suất cao là$\mathbf{M}^k \mathbf{x}$ sẽ xấp xỉ vector trị riêng lớn nhất$\mathbf{v}_n$ của$\mathbf{M}$ . - Kết quả trên có ý nghĩa như thế nào khi tính gradient của mạng nơ-ron hồi tiếp?
- Ngoài gọt gradient, có phương pháp nào để xử lý bùng nổ gradient trong mạng nơ-ron hồi tiếp không?
Bản dịch trong trang này được thực hiện bởi:
- Đoàn Võ Duy Thanh
- Nguyễn Văn Quang
- Lê Khắc Hồng Phúc
- Nguyễn Văn Cường
- Phạm Minh Đức
- Phạm Hồng Vinh