Skip to content

Commit

Permalink
updated version
Browse files Browse the repository at this point in the history
  • Loading branch information
TobiasHeOl committed Mar 15, 2024
1 parent 36efc66 commit 5eb4f31
Show file tree
Hide file tree
Showing 2 changed files with 236 additions and 19 deletions.
253 changes: 235 additions & 18 deletions notebooks/pretrained_module.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"id": "7ae54cd0-6253-46dd-a316-4f20b12041e0",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -33,7 +33,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"id": "99192978-a008-4a32-a80e-bba238e0ec7c",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -75,7 +75,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"id": "0e7419e4-db22-49ea-8e12-6db2b3681545",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -126,7 +126,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"id": "ceae4a88-0679-4704-8bad-c06a4569c497",
"metadata": {},
"outputs": [],
Expand All @@ -153,10 +153,30 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"id": "d22f4302-1262-4cc1-8a1c-a36daa8c710c",
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"array([[-0.2520631 , 0.18189636, 0.00887139, ..., 0.15365511,\n",
" -0.14508604, -0.1338132 ],\n",
" [-0.25149409, 0.20864547, 0.07518204, ..., 0.19478276,\n",
" -0.15227771, -0.08241642],\n",
" [-0.27468957, 0.16507224, 0.08667156, ..., 0.18776285,\n",
" -0.14165093, -0.16389883],\n",
" [-0.19822127, 0.16841082, -0.0492593 , ..., 0.11400163,\n",
" -0.14723686, -0.09713166],\n",
" [-0.29553183, 0.17239204, 0.05676914, ..., 0.15943631,\n",
" -0.16615378, -0.15569783]])"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ablang(all_seqs, mode='seqcoding')"
]
Expand All @@ -177,10 +197,85 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"id": "6227f661-575f-4b1e-9646-cfba7b10c3b4",
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"[array([[-0.40741202, -0.51189876, 0.06096718, ..., 0.32681432,\n",
" 0.03920227, -0.36715847],\n",
" [-0.5768882 , 0.38245377, -0.21792021, ..., 0.01250281,\n",
" -0.08844489, -0.32367533],\n",
" [-0.14759329, 0.39639032, -0.38226995, ..., -0.10119925,\n",
" -0.41469547, -0.00319326],\n",
" ...,\n",
" [-0.1435836 , 0.31243888, -0.30157977, ..., -0.13289277,\n",
" -0.45353436, -0.07878845],\n",
" [ 0.17538942, 0.24394313, 0.20141156, ..., 0.14587337,\n",
" -0.38479012, 0.07409145],\n",
" [-0.23031712, -0.354873 , 0.19606796, ..., -0.12833637,\n",
" 0.3110731 , -0.3265107 ]], dtype=float32),\n",
" array([[-0.4198183 , -0.36663735, 0.1059521 , ..., 0.39035723,\n",
" 0.0382379 , -0.36337999],\n",
" [-0.50541353, 0.38347134, -0.10992067, ..., -0.05231511,\n",
" -0.13636601, -0.34830102],\n",
" [-0.06784626, 0.69349885, -0.4212396 , ..., -0.24805343,\n",
" -0.39583787, -0.10972748],\n",
" ...,\n",
" [-0.20900953, 0.29489496, -0.11039101, ..., -0.24245393,\n",
" -0.60625213, -0.02307976],\n",
" [ 0.19134362, 0.21744648, 0.25758275, ..., 0.1584544 ,\n",
" -0.3474367 , 0.10218269],\n",
" [-0.25511587, -0.21778467, 0.21906386, ..., -0.09656096,\n",
" 0.22394848, -0.20267344]], dtype=float32),\n",
" array([[-0.40043744, -0.48596832, 0.08867243, ..., 0.38941652,\n",
" 0.06195954, -0.40999684],\n",
" [-0.5457607 , 0.43129578, -0.34514865, ..., -0.09285577,\n",
" 0.03116523, -0.45269734],\n",
" [ 0.0221168 , 0.5319657 , -0.30137247, ..., -0.18890701,\n",
" -0.3258736 , 0.05078411],\n",
" ...,\n",
" [ 0.2630384 , -0.22976035, 0.55103725, ..., 0.47436467,\n",
" -0.42733553, -0.83135855],\n",
" [-0.13752194, 0.28678605, -0.18887033, ..., 0.28262642,\n",
" 0.12546761, -0.6496488 ],\n",
" [-0.45414186, 0.24564977, 0.2132736 , ..., 0.03287451,\n",
" 0.03825564, -0.3425912 ]], dtype=float32),\n",
" array([[-0.2686321 , 0.32259196, 0.10813516, ..., 0.03953857,\n",
" 0.18312067, -0.00498033],\n",
" [-0.21654248, -0.38562426, -0.02696253, ..., 0.20541485,\n",
" 0.18698384, -0.22639509],\n",
" [-0.41950503, 0.04743315, 0.00488149, ..., 0.11408655,\n",
" -0.05384672, 0.10258742],\n",
" ...,\n",
" [-0.1096048 , 0.35151383, -0.2175244 , ..., -0.21448924,\n",
" -0.6396221 , -0.00839772],\n",
" [ 0.2049191 , 0.36294493, 0.19217433, ..., 0.07750694,\n",
" -0.50392145, 0.03793862],\n",
" [-0.11638469, -0.35350844, 0.13215733, ..., -0.16060586,\n",
" 0.2391388 , -0.25653362]], dtype=float32),\n",
" array([[-0.42062938, -0.44009122, 0.00152369, ..., 0.27141467,\n",
" 0.03798108, -0.397461 ],\n",
" [-0.5731807 , 0.52588975, -0.1700168 , ..., -0.23864639,\n",
" 0.20880571, -0.5787758 ],\n",
" [-0.38988566, 0.46168268, -0.34294134, ..., -0.14872617,\n",
" -0.4657687 , -0.21225002],\n",
" ...,\n",
" [-0.21528657, 0.3004676 , -0.2521646 , ..., -0.11576824,\n",
" -0.4704909 , -0.07401361],\n",
" [ 0.06330815, 0.22700705, 0.2818417 , ..., 0.15967268,\n",
" -0.3771821 , 0.06188553],\n",
" [-0.2782629 , -0.3729748 , 0.21229891, ..., -0.14886044,\n",
" 0.24998347, -0.35954222]], dtype=float32)]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ablang(all_seqs, mode='rescoding', stepwise_masking = False)"
]
Expand All @@ -201,10 +296,80 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"id": "e4bc0cb1-f5b0-4255-9e93-d643ae1396df",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['<' '1 ' '2 ' '3 ' '4 ' '5 ' '6 ' '7 ' '8 ' '9 ' '11 ' '12 ' '13 ' '14 '\n",
" '15 ' '16 ' '17 ' '18 ' '19 ' '20 ' '21 ' '22 ' '23 ' '24 ' '25 ' '26 '\n",
" '27 ' '28 ' '29 ' '30 ' '35 ' '36 ' '37 ' '38 ' '39 ' '40 ' '41 ' '42 '\n",
" '43 ' '44 ' '45 ' '46 ' '47 ' '48 ' '49 ' '50 ' '51 ' '52 ' '53 ' '54 '\n",
" '55 ' '56 ' '57 ' '58 ' '59 ' '62 ' '63 ' '64 ' '65 ' '66 ' '67 ' '68 '\n",
" '69 ' '70 ' '71 ' '72 ' '74 ' '75 ' '76 ' '77 ' '78 ' '79 ' '80 ' '81 '\n",
" '82 ' '83 ' '84 ' '85 ' '86 ' '87 ' '88 ' '89 ' '90 ' '91 ' '92 ' '93 '\n",
" '94 ' '95 ' '96 ' '97 ' '98 ' '99 ' '100 ' '101 ' '102 ' '103 ' '104 '\n",
" '105 ' '106 ' '107 ' '108 ' '109 ' '110 ' '111 ' '112A' '112 ' '113 '\n",
" '114 ' '115 ' '116 ' '117 ' '118 ' '119 ' '120 ' '121 ' '122 ' '123 '\n",
" '124 ' '125 ' '126 ' '127 ' '128 ' '>' '|' '<' '1 ' '2 ' '3 ' '4 ' '5 '\n",
" '6 ' '7 ' '8 ' '9 ' '10 ' '11 ' '12 ' '13 ' '14 ' '15 ' '16 ' '17 ' '18 '\n",
" '19 ' '20 ' '21 ' '22 ' '23 ' '24 ' '25 ' '26 ' '27 ' '28 ' '29 ' '30 '\n",
" '31 ' '32 ' '34 ' '35 ' '36 ' '37 ' '38 ' '39 ' '40 ' '41 ' '42 ' '43 '\n",
" '44 ' '45 ' '46 ' '47 ' '48 ' '49 ' '50 ' '51 ' '52 ' '53 ' '54 ' '55 '\n",
" '56 ' '57 ' '64 ' '65 ' '66 ' '67 ' '68 ' '69 ' '70 ' '71 ' '72 ' '74 '\n",
" '75 ' '76 ' '77 ' '78 ' '79 ' '80 ' '83 ' '84 ' '85 ' '86 ' '87 ' '88 '\n",
" '89 ' '90 ' '91 ' '92 ' '93 ' '94 ' '95 ' '96 ' '97 ' '98 ' '99 ' '100 '\n",
" '101 ' '102 ' '103 ' '104 ' '105 ' '106 ' '107 ' '108 ' '109 ' '114 '\n",
" '115 ' '116 ' '117 ' '118 ' '119 ' '120 ' '121 ' '122 ' '123 ' '124 '\n",
" '125 ' '126 ' '127 ' '>']\n",
"['<EVQLLESGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCARDVPGHGAAFMDVWGTGTTVTVSS>|<DIQLTQSPLSLPVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKI-SNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK>', '<EVQLLESGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCARDVPGHGAAFMDVWGTGTT----->|<-----------PVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKI-SNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK>', '<------SGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCAR**PGHGAAFMDVWGTGTTVTVSS>|<DIQLTQSPLSLPVTLGQPASISCRSS*SLEASDTNIYLSWFQQRPGQSPRRLIYKI*N-RDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK>']\n",
"[[[ 9.31621552 -3.42184424 -3.59398293 ... -14.73707485 -6.8935895\n",
" -0.23662716]\n",
" [ -3.54718328 -5.8486681 -4.02423763 ... -12.9396677 -9.56145287\n",
" -4.48474121]\n",
" [-11.94997597 -2.2455442 -5.69481659 ... -15.1963892 -17.97455025\n",
" -12.56952667]\n",
" ...\n",
" [ -8.94505119 -0.42261413 -4.95588017 ... -16.66817665 -15.22247696\n",
" -10.37267685]\n",
" [-11.65150261 -5.44477367 -2.95585799 ... -16.25555801 -9.75158882\n",
" -11.75897026]\n",
" [ 1.79469967 -1.95846725 -3.59784651 ... -14.95585823 -7.47080421\n",
" -0.95226705]]\n",
"\n",
" [[ 8.55518723 -3.83663583 -2.33596039 ... -13.87456799 -8.14840603\n",
" -0.42472461]\n",
" [ -4.4070158 -5.53201628 -3.69397473 ... -12.97877884 -9.86258984\n",
" -4.95414734]\n",
" [-11.95642948 -3.86210847 -5.80935097 ... -14.89213085 -16.94556236\n",
" -11.36959457]\n",
" ...\n",
" [ -7.75924206 -0.66524088 -4.08643246 ... -16.16580582 -14.76506901\n",
" -8.35070801]\n",
" [-11.91039467 -4.86995649 -2.74777317 ... -16.07694817 -8.44974518\n",
" -10.45223522]\n",
" [ 0.86006927 -2.37964129 -3.58130884 ... -15.35423565 -7.7303524\n",
" -1.11989462]]\n",
"\n",
" [[ -4.37903118 -7.55587101 1.21958244 ... -15.48622799 -6.02184772\n",
" -3.7964797 ]\n",
" [ 0. 0. 0. ... 0. 0.\n",
" 0. ]\n",
" [ 0. 0. 0. ... 0. 0.\n",
" 0. ]\n",
" ...\n",
" [ -8.94207573 -0.51090133 -5.09760666 ... -16.69521904 -15.45450687\n",
" -10.50823021]\n",
" [-11.92355251 -5.55152798 -2.87667084 ... -16.40608025 -10.19431782\n",
" -12.13288021]\n",
" [ 2.42199802 -2.01573205 -3.61701035 ... -14.9590435 -7.19029284\n",
" -0.89830101]]]\n"
]
}
],
"source": [
"results = ablang(only_both_chains_seqs, mode='likelihood', align=True)\n",
"\n",
Expand Down Expand Up @@ -233,21 +398,45 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"id": "83f3064b-48a7-42fb-ba82-ec153ea946da",
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"array([1.995193 , 2.017602 , 2.1375413, 1.8546418, 2.0021744],\n",
" dtype=float32)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"results = ablang(all_seqs, mode='pseudo_log_likelihood')\n",
"np.exp(-results) # convert to pseudo perplexity"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"id": "42cc8b34-5ae9-4857-93fe-a438a0f2a868",
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"array([1.2699332, 1.1272193, 1.3212233, 1.2203734, 1.1848254],\n",
" dtype=float32)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"results = ablang(all_seqs, mode='confidence')\n",
"np.exp(-results)"
Expand All @@ -265,21 +454,49 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 10,
"id": "2d5b725c-4eac-4a4b-9331-357c3ac140f7",
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"array(['<EVQLLESGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCARDVPGHGAAFMDVWGTGTTVTVSS>|<DIQLTQSPLSLPVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKISNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK>',\n",
" '<EVQLLESGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCARDVPGHGAAFMDVWGTGTT>|<PVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKISNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK>',\n",
" '<EVQLVQSGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCARDPPGHGAAFMDVWGTGTTVTVSS>|<DIQLTQSPLSLPVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKISNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK>'],\n",
" dtype='<U238')"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"restored = ablang(only_both_chains_seqs, mode='restore')\n",
"restored"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 11,
"id": "0e9615f7-c490-4947-96f4-7617266c686e",
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"array(['<EVQLLESGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCARDVPGHGAAFMDVWGTGTTVTVSS>|<DIQLTQSPLSLPVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKISNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK>',\n",
" '<EVQLLESGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCARDVPGHGAAFMDVWGTGTTVTVSS>|<DVVMTQSPLSLPVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKISNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK>',\n",
" '<QVQLVQSGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCARDPPGHGAAFMDVWGTGTTVTVSS>|<DIQLTQSPLSLPVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKISNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK>'],\n",
" dtype='<U238')"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"restored = ablang(only_both_chains_seqs, mode='restore', align = True)\n",
"restored"
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

setup(
name='ablang2',
version='0.1.0',
version='0.1.1',
license='BSD 3-clause license',
description='AbLang2: An antibody-specific language model focusing on NGL prediction.',
long_description=long_description,
Expand Down

0 comments on commit 5eb4f31

Please sign in to comment.