forked from TuringLang/Turing.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
42 changed files
with
2,069 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
Copyright (c) 2016 Hong Ge, Adam Scibior | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,84 @@ | ||
# Turing.jl | ||
# Turing.jl | ||
Turing is a Julia library for probabilistic programming. A Turing probabilistic program is just a normal Julia program, wrapped in a `@model` macro, that uses some of the special macros listed below. Available inference methods include Importance Sampling, Sequential Monte Carlo, Particle Gibbs. | ||
|
||
### Example | ||
```julia | ||
@model gaussdemo begin | ||
# Define a simple Normal model with unknown mean and variance. | ||
@assume s ~ InverseGamma(2,3) | ||
@assume m ~ Normal(0,sqrt(s)) | ||
@observe 1.5 ~ Normal(m, sqrt(s)) | ||
@observe 2.0 ~ Normal(m, sqrt(s)) | ||
@predict s m | ||
end | ||
``` | ||
|
||
## Installation | ||
|
||
You will need Julia 0.4, which you can get from the official Julia [website](http://julialang.org/downloads/). We recommend that you install a pre-compiled package, as Turing may not work correctly with Julia built form source. | ||
|
||
Inside Julia, run the following: | ||
|
||
```julia | ||
Pkg.clone("https://github.com/yebai/Turing.jl") | ||
Pkg.build("Turing") | ||
Pkg.test("Turing") | ||
``` | ||
|
||
If all tests pass, you're ready to start using Turing. | ||
|
||
## Modelling API | ||
A probabilistic program is Julia code wrapped in a `@model` macro. It can use arbitrary Julia code, but to ensure correctness of inference it should not have external effects or modify global state. Stack-allocated variables are safe, but mutable heap-allocated objects may lead to subtle bugs when using task copying. To help avoid those we provide a Turing-safe datatype `TArray` that can be used to create mutable arrays in Turing programs. | ||
|
||
For probabilistic effects, Turing programs should use the following macros: | ||
|
||
`@assume x ~ distr` | ||
where `x` is a symbol and `distr` is a distribution. Inside the probabilistic program this puts a random variable named `x`, distributed according to `distr`, in the current scope. `distr` can be a value of any type that implements `rand(distr)`, which samples a value from the distribution `distr`. | ||
|
||
`@observe y ~ distr` | ||
This is used for conditioning in a style similar to Anglican. Here `y` should be a value that is observed to have been drawn from the distribution `distr`. The likelihood is computed using `pdf(distr,y)` and should always be positive to ensure correctness of inference algorithms. The observe statements should be arranged so that every possible run traverses all of them in exactly the same order. This is equivalent to demanding that they are not placed inside stochastic control flow. | ||
|
||
`@predict x` | ||
Registers the current value of `x` to be inspected in the results of inference. | ||
|
||
## Inference API | ||
Inference methods are functions which take the probabilistic program as one of the arguments. | ||
```julia | ||
# run sampler, collect results | ||
chain = sample(gaussdemo, SMC(500)) | ||
chain = sample(gaussdemo, PG(10,500)) | ||
``` | ||
|
||
## Task copying | ||
Turing [copies](https://github.com/JuliaLang/julia/issues/4085) Julia tasks to deliver efficient inference algorithms, but it also provides alternative slower implementation as a fallback. Task copying is enabled by default. Task copying requires building a small C program, which should be done automatically on Linux and Mac systems that have GCC and Make installed. | ||
|
||
## Development notes | ||
Following GitHub guidelines, we have two main branches: master and development. We protect them with a review process to make sure at least two people approve the code before it is committed to either of them. For this reason, do not commit to either master or development directly. | ||
Please use the following workflow instead. | ||
|
||
### Reporting bugs | ||
- branch from master | ||
- write a test that exposes the bug | ||
- create an issue describing the bug, referencing the new branch | ||
|
||
### Bug fixes | ||
- assign yourself to the relevant issue | ||
- fix the bug on the dedicated branch | ||
- see that the new test (and all the old ones) passes | ||
- create a pull request | ||
- when a pull request is accepted close the issue and propagate changes to development | ||
|
||
### New features and performance enhancements | ||
- create an issue describing the proposed change | ||
- branch from development | ||
- write tests for the new features | ||
- implement the feature | ||
- see that all tests pass | ||
- create a pull request | ||
|
||
### Merging development into master | ||
- review changes from master | ||
- create a pull request | ||
|
||
## External contributions | ||
Turing is an open-source project and we welcome any and all contributions from the community. If you have comments, questions, suggestions, or bug reports, please open an issue and let us know. If you want to contribute bug fixes or new features, please fork the repo and make a pull requrest. If you'd like to make a more substantial contribution to Turing, please get in touch so we can discuss the best way to proceed. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
julia 0.4 | ||
|
||
Distributions | ||
ConjugatePriors |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
CC=gcc | ||
JL_SHARE = $(shell julia -e 'print(joinpath(JULIA_HOME,Base.DATAROOTDIR,"julia"))') | ||
CFLAGS += $(shell julia $(JL_SHARE)/julia-config.jl --cflags) | ||
CXXFLAGS += $(shell julia $(JL_SHARE)/julia-config.jl --cflags) | ||
LDFLAGS += $(shell julia $(JL_SHARE)/julia-config.jl --ldflags) | ||
LDLIBS += $(shell julia $(JL_SHARE)/julia-config.jl --ldlibs) | ||
|
||
task: task.c | ||
$(CC) $(CFLAGS) -O2 -shared -fPIC task.c $(LDFLAGS) $(LDLIBS) -o libtask.$(LIBEXT) | ||
|
||
platform=$(shell uname) | ||
ifeq ($(platform),Linux) | ||
LIBEXT=so | ||
else ifeq ($(platform),Darwin) | ||
LIBEXT=dylib | ||
else | ||
LIBEXT=dll | ||
endif | ||
|
||
clean: | ||
-rm libtask.$(LIBEXT) | ||
|
||
.PHONY: \ | ||
task | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
run(`make`) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
/* | ||
task.c | ||
lightweight processes (symmetric coroutines) | ||
*/ | ||
|
||
#include "julia.h" | ||
|
||
jl_task_t *jl_clone_task(jl_task_t *t) | ||
{ | ||
jl_task_t *newt = (jl_task_t*)jl_gc_allocobj(sizeof(jl_task_t)); | ||
jl_set_typeof(newt, jl_task_type); | ||
newt->stkbuf = NULL; | ||
newt->gcstack = NULL; | ||
JL_GC_PUSH1(&newt); | ||
|
||
newt->parent = t->parent; | ||
newt->last = t->last; | ||
newt->current_module = t->current_module; | ||
newt->state = t->state; | ||
newt->start = t->start; | ||
newt->tls = jl_nothing; | ||
newt->consumers = jl_nothing; | ||
newt->result = jl_nothing; | ||
newt->donenotify = jl_nothing; | ||
newt->exception = jl_nothing; | ||
newt->backtrace = jl_nothing; | ||
newt->eh = NULL; | ||
newt->gcstack = t->gcstack; | ||
|
||
/* | ||
jl_printf(JL_STDOUT,"t: %p\n", t); | ||
jl_printf(JL_STDOUT,"t->stkbuf: %p\n", t->stkbuf); | ||
jl_printf(JL_STDOUT,"t->gcstack: %p\n", t->gcstack); | ||
jl_printf(JL_STDOUT,"t->bufsz: %zu\n", t->bufsz); | ||
*/ | ||
|
||
memcpy((void*)newt->ctx, (void*)t->ctx, sizeof(jl_jmp_buf)); | ||
#ifdef COPY_STACKS | ||
if (t->stkbuf){ | ||
newt->ssize = t->ssize; // size of saved piece | ||
// newt->stkbuf = allocb(t->bufsz); // needs to be allocb(t->bufsz) | ||
// newt->bufsz = t->bufsz; | ||
// memcpy(newt->stkbuf, t->stkbuf, t->bufsz); | ||
// workaround, newt and t will get new stkbuf when savestack is called. | ||
t->bufsz = 0; | ||
newt->bufsz = 0; | ||
newt->stkbuf = t->stkbuf; | ||
}else{ | ||
newt->ssize = 0; | ||
newt->bufsz = 0; | ||
newt->stkbuf = NULL; | ||
} | ||
#else | ||
#error task copying not supported yet. | ||
#endif | ||
JL_GC_POP(); | ||
jl_gc_wb_back(newt); | ||
|
||
return newt; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
This directory contains Turing models. Please use the following conventions: | ||
- only one model per file, use the same name for the file and the model | ||
- all model files should contain the line "using Turing" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
using Turing | ||
|
||
K = 10 | ||
N = 51 | ||
|
||
T = Array{Array{Float64}}( | ||
{{ 2.72545164e-01, 6.56376310e-37, 7.93970182e-06, | ||
8.61092244e-05, 1.65656847e-01, 2.87081146e-01, | ||
2.18659173e-04, 4.06806091e-13, 2.74399041e-01, | ||
5.09466694e-06}, | ||
{ 8.35696383e-02, 7.05396766e-06, 2.06100213e-01, | ||
1.52600546e-11, 1.53908757e-08, 6.52698080e-11, | ||
6.73013596e-01, 1.93819808e-02, 1.79275025e-02, | ||
3.22727890e-18}, | ||
{ 8.86728867e-01, 8.63212428e-07, 3.94046562e-11, | ||
1.10467781e-01, 1.11735715e-06, 2.12946427e-18, | ||
5.16880247e-08, 2.36340361e-04, 1.11600188e-03, | ||
1.44897701e-03}, | ||
{ 8.16557898e-11, 1.68992485e-02, 6.27227283e-14, | ||
1.07672980e-02, 4.26756301e-01, 1.78050405e-09, | ||
6.41316248e-05, 3.21159984e-01, 3.69047216e-02, | ||
1.87448314e-01}, | ||
{ 1.30652141e-01, 4.70465408e-06, 4.73601393e-04, | ||
3.78509164e-02, 9.06618543e-10, 7.78622816e-04, | ||
2.93837790e-04, 8.29914494e-01, 8.69688804e-26, | ||
3.16817317e-05}, | ||
{ 2.23054659e-01, 2.88152163e-01, 7.35806925e-19, | ||
1.85562602e-02, 1.73073908e-08, 4.00936069e-01, | ||
1.17437994e-12, 4.43974641e-02, 2.49033671e-02, | ||
7.22022850e-18}, | ||
{ 4.78064507e-16, 3.20444079e-01, 3.85904296e-03, | ||
1.26156421e-09, 6.88364264e-03, 4.47186979e-03, | ||
1.56660567e-01, 1.69796226e-01, 3.33780163e-01, | ||
4.10440864e-03}, | ||
{ 4.84470444e-13, 2.50251630e-14, 2.78748146e-07, | ||
2.45132866e-03, 3.03033036e-12, 2.84425237e-03, | ||
8.49830551e-07, 6.60111797e-08, 9.94702713e-01, | ||
5.11768273e-07}, | ||
{ 1.68396450e-01, 1.80280379e-07, 5.68958062e-11, | ||
1.34838199e-01, 2.08104310e-04, 8.61188042e-02, | ||
5.17409105e-02, 3.61825373e-01, 3.31239961e-11, | ||
1.96871978e-01}, | ||
{ 4.74764005e-02, 1.16126593e-06, 5.96036112e-08, | ||
1.28470373e-03, 1.30134792e-06, 3.74283978e-02, | ||
3.10068428e-01, 2.27075277e-19, 6.47484474e-03, | ||
5.97264703e-01}}) | ||
|
||
obs = Array{Float64}( | ||
{ 0.0, 7.72711051, 2.76189162, 8.8216901 , | ||
10.80174329, 8.87655587, 0.47685358, 9.51892527, | ||
7.82538035, 5.52629325, 10.75167786, 5.94925434, | ||
-0.96912603, 1.65160838, 1.65005965, -0.99642713, | ||
7.37803004, 5.40821392, 9.44046498, 8.51761132, | ||
9.76981763, 5.980154 , 9.19558142, 5.33965621, | ||
6.2388448 , 2.77755879, 6.67731151, 8.52411613, | ||
11.31057577, 8.11554144, 6.64705471, 8.02025435, | ||
9.84003587, 3.03943679, -2.93966727, 2.04372567, | ||
-0.93734763, 3.66943525, 6.12876571, -2.07758649, | ||
1.10420963, -0.23197037, 3.64908206, 14.14671815, | ||
6.96651114, 7.28554932, 9.06049355, 6.54246834, | ||
11.22672275, 7.41962631, 8.45635411}) | ||
|
||
means = zeros(Float64,K) | ||
initial = fill(1.0 / K, K) | ||
for i = 1:K | ||
means[i] = i | ||
end | ||
|
||
@model big_hmm begin | ||
states = tzeros(Int64,N) | ||
@assume states[1] ~ Categorical(initial) | ||
for i = 2:N | ||
@assume states[i] ~ Categorical(T[states[i-1]]) | ||
@observe obs[i] ~ Normal(means[states[i]], 4) | ||
end | ||
@predict states | ||
end | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
using Turing | ||
|
||
@model gaussdemo begin | ||
@assume s ~ InverseGamma(2,3) | ||
@assume m ~ Normal(0,sqrt(s)) | ||
@observe 1.5 ~ Normal(m, sqrt(s)) | ||
@observe 2.0 ~ Normal(m, sqrt(s)) | ||
@predict s m | ||
end | ||
|
||
# Sample and print. | ||
res = sample(gaussdemo, SMC(10000)) | ||
println("Infered: m = $(mean(res[:m])), s = $(mean(res[:s]))") | ||
|
||
# Compute analytical solution. Requires `ConjugatePriors` package. | ||
exact = posterior(NormalInverseGamma(0,1,2,3), Normal, [1.5,2.0]) | ||
println("Exact: m = $(mean(exact)[1]), s = $(mean(exact)[2])") |
Oops, something went wrong.