Skip to content

Commit

Permalink
Merge commit '2975f539ff8ac9b8e07fb2b610bd69a1596d4c3c'
Browse files Browse the repository at this point in the history
  • Loading branch information
soumith committed Dec 30, 2016
2 parents 9a40821 + 2975f53 commit d42eadf
Show file tree
Hide file tree
Showing 8 changed files with 112 additions and 124 deletions.
3 changes: 3 additions & 0 deletions torch/lib/THC/THCCachingHostAllocator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ struct HostAllocator
return cudaSuccess;
}

// note that cudaHostAlloc may not touch pointer if size is 0
*ptr = 0;

// allocate a new block if no cached allocation is found
err = cudaHostAlloc(ptr, size, cudaHostAllocDefault);
if (err != cudaSuccess) {
Expand Down
14 changes: 12 additions & 2 deletions torch/lib/THC/THCGenerateHalfType.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,32 @@

#include "THCHalf.h"

#ifdef CUDA_HALF_TENSOR
#if defined(CUDA_HALF_TENSOR) || defined(FORCE_TH_HALF)

#define real half
#define accreal float
#define Real Half

// if only here via FORCE_TH_HALF, don't define CReal since
// FORCE_TH_HALF should only be used for TH types
#ifdef CUDA_HALF_TENSOR
#define CReal CudaHalf
#endif

#define THC_REAL_IS_HALF
#line 1 THC_GENERIC_FILE
#include THC_GENERIC_FILE
#undef real
#undef accreal
#undef Real

#ifdef CUDA_HALF_TENSOR
#undef CReal
#endif

#undef THC_REAL_IS_HALF

#endif // CUDA_HALF_TENSOR
#endif // defined(CUDA_HALF_TENSOR) || defined(FORCE_TH_HALF)

#ifndef THCGenerateAllTypes
#ifndef THCGenerateFloatTypes
Expand Down
151 changes: 74 additions & 77 deletions torch/lib/THC/THCHalf.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,93 +30,90 @@ void THCHalf2Float(THCState *state, float *out, half *in, ptrdiff_t len) {
in, in + len, out, __half2floatOp());
}

float THC_half2float(half a)
// FixMe: could call TH_half2float
// and convert types here, but maybe slower?
float THC_half2float(half h)
{
unsigned int bits = a.x & 0x7fff;
unsigned int sign = a.x & 0x8000;
unsigned int exp = a.x & 0x7c00;
unsigned sign = ((h.x >> 15) & 1);
unsigned exponent = ((h.x >> 10) & 0x1f);
unsigned mantissa = ((h.x & 0x3ff) << 13);

if (exponent == 0x1f) { /* NaN or Inf */
mantissa = (mantissa ? (sign = 0, 0x7fffff) : 0);
exponent = 0xff;
} else if (!exponent) { /* Denorm or Zero */
if (mantissa) {
unsigned int msb;
exponent = 0x71;
do {
msb = (mantissa & 0x400000);
mantissa <<= 1; /* normalize */
--exponent;
} while (!msb);
mantissa &= 0x7fffff; /* 1.mantissa is implicit */
}
} else {
exponent += 0x70;
}

bits <<= 13;
sign <<= 16;
int temp = ((sign << 31) | (exponent << 23) | mantissa);

bits += 0x38000000U;
return *((float*)((void*)&temp));
}

// flush denormals to 0
bits = (exp == 0 ? 0 : bits) | sign;
half THC_float2half(float f)
{
half ret;

union {
float f;
unsigned int v;
} conv;
conv.v = bits;
unsigned x = *((int*)(void*)(&f));
unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1;
unsigned sign, exponent, mantissa;

return conv.f;
}
// Get rid of +NaN/-NaN case first.
if (u > 0x7f800000) {
ret.x = 0x7fffU;
return ret;
}

/*
Copyright (c) 2015, Norbert Juffa
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

half THC_float2half(float a)
{
uint32_t ia;
uint16_t ir;
memcpy(&ia, &a, sizeof(float));

ir = (ia >> 16) & 0x8000;
if ((ia & 0x7f800000) == 0x7f800000) {
if ((ia & 0x7fffffff) == 0x7f800000) {
ir |= 0x7c00; /* infinity */
} else {
ir = 0x7fff; /* canonical NaN */
}
} else if ((ia & 0x7f800000) >= 0x33000000) {
int shift = (int)((ia >> 23) & 0xff) - 127;
if (shift > 15) {
ir |= 0x7c00; /* infinity */
} else {
ia = (ia & 0x007fffff) | 0x00800000; /* extract mantissa */
if (shift < -14) { /* denormal */
ir |= ia >> (-1 - shift);
ia = ia << (32 - (-1 - shift));
} else { /* normal */
ir |= ia >> (24 - 11);
ia = ia << (32 - (24 - 11));
ir = ir + ((14 + shift) << 10);
}
/* IEEE-754 round to nearest of even */
if ((ia > 0x80000000) || ((ia == 0x80000000) && (ir & 1))) {
ir++;
}
sign = ((x >> 16) & 0x8000);

// Get rid of +Inf/-Inf, +0/-0.
if (u > 0x477fefff) {
ret.x = sign | 0x7c00U;
return ret;
}
if (u < 0x33000001) {
ret.x = (sign | 0x0000);
return ret;
}

exponent = ((u >> 23) & 0xff);
mantissa = (u & 0x7fffff);

if (exponent > 0x70) {
shift = 13;
exponent -= 0x70;
} else {
shift = 0x7e - exponent;
exponent = 0;
mantissa |= 0x800000;
}
lsb = (1 << shift);
lsb_s1 = (lsb >> 1);
lsb_m1 = (lsb - 1);

// Round to nearest even.
remainder = (mantissa & lsb_m1);
mantissa >>= shift;
if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) {
++mantissa;
if (!(mantissa & 0x3ff)) {
++exponent;
mantissa = 0;
}
}

half ret;
memcpy(&ret, &ir, sizeof(half));
ret.x = (sign | (exponent << 10) | mantissa);
return ret;
}

Expand Down
6 changes: 2 additions & 4 deletions torch/lib/THC/generic/THCStorageCopy.c
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@
#define THC_GENERIC_FILE "generic/THCStorageCopy.c"
#else

#ifndef THC_REAL_IS_HALF
void THCStorage_(copyCPU)(THCState *state, THCStorage *self, struct THStorage *src)
{
THArgCheck(self->size == src->size, 2, "size does not match");
THCudaCheck(cudaMemcpy(self->data, src->data, self->size * sizeof(real), cudaMemcpyHostToDevice));
}
#endif

#define TH_CUDA_STORAGE_IMPLEMENT_COPY(TYPEC) \
void THCStorage_(copy##TYPEC)(THCState *state, THCStorage *self, struct TH##TYPEC##Storage *src) \
Expand All @@ -27,15 +25,14 @@ TH_CUDA_STORAGE_IMPLEMENT_COPY(Short)
TH_CUDA_STORAGE_IMPLEMENT_COPY(Int)
TH_CUDA_STORAGE_IMPLEMENT_COPY(Long)
TH_CUDA_STORAGE_IMPLEMENT_COPY(Float)
TH_CUDA_STORAGE_IMPLEMENT_COPY(Half)
TH_CUDA_STORAGE_IMPLEMENT_COPY(Double)

#ifndef THC_REAL_IS_HALF
void THStorage_(copyCuda)(THCState *state, THStorage *self, struct THCStorage *src)
{
THArgCheck(self->size == src->size, 2, "size does not match");
THCudaCheck(cudaMemcpy(self->data, src->data, self->size * sizeof(real), cudaMemcpyDeviceToHost));
}
#endif

#define TH_CUDA_STORAGE_IMPLEMENT_COPYTO(TYPEC) \
void TH_CONCAT_4(TH,TYPEC,Storage_copyCuda,Real)(THCState *state, TH##TYPEC##Storage *self, struct THCStorage *src) \
Expand All @@ -54,6 +51,7 @@ TH_CUDA_STORAGE_IMPLEMENT_COPYTO(Short)
TH_CUDA_STORAGE_IMPLEMENT_COPYTO(Int)
TH_CUDA_STORAGE_IMPLEMENT_COPYTO(Long)
TH_CUDA_STORAGE_IMPLEMENT_COPYTO(Float)
TH_CUDA_STORAGE_IMPLEMENT_COPYTO(Half)
TH_CUDA_STORAGE_IMPLEMENT_COPYTO(Double)

#undef TH_CUDA_STORAGE_IMPLEMENT_COPY
Expand Down
5 changes: 2 additions & 3 deletions torch/lib/THC/generic/THCStorageCopy.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ THC_API void THCStorage_(copyInt)(THCState *state, THCStorage *storage, struct T
THC_API void THCStorage_(copyLong)(THCState *state, THCStorage *storage, struct THLongStorage *src);
THC_API void THCStorage_(copyFloat)(THCState *state, THCStorage *storage, struct THFloatStorage *src);
THC_API void THCStorage_(copyDouble)(THCState *state, THCStorage *storage, struct THDoubleStorage *src);
THC_API void THCStorage_(copyHalf)(THCState *state, THCStorage *storage, struct THHalfStorage *src);

THC_API void THCStorage_(copyCudaByte)(THCState *state, THCStorage *storage, struct THCudaByteStorage *src);
THC_API void THCStorage_(copyCudaChar)(THCState *state, THCStorage *storage, struct THCudaCharStorage *src);
Expand All @@ -32,12 +33,10 @@ THC_API void TH_CONCAT_2(THIntStorage_copyCuda , Real)(THCState *state, THIntS
THC_API void TH_CONCAT_2(THLongStorage_copyCuda , Real)(THCState *state, THLongStorage *self, struct THCStorage *src);
THC_API void TH_CONCAT_2(THFloatStorage_copyCuda , Real)(THCState *state, THFloatStorage *self, struct THCStorage *src);
THC_API void TH_CONCAT_2(THDoubleStorage_copyCuda, Real)(THCState *state, THDoubleStorage *self, struct THCStorage *src);
THC_API void TH_CONCAT_2(THHalfStorage_copyCuda, Real)(THCState *state, THHalfStorage *self, struct THCStorage *src);

/* There is no THHalfStorage */
#ifndef THC_REAL_IS_HALF
THC_API void THStorage_(copyCuda)(THCState *state, THStorage *self, THCStorage *src);
THC_API void THCStorage_(copyCuda)(THCState *state, THCStorage *self, THCStorage *src);
THC_API void THCStorage_(copyCPU)(THCState *state, THCStorage *self, THStorage *src);
#endif

#endif
36 changes: 2 additions & 34 deletions torch/lib/THC/generic/THCTensorCopy.c
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

/* specific methods */

#ifndef THC_REAL_IS_HALF
void THCTensor_(copyCPU)(THCState *state, THCTensor *self, struct THTensor *src)
{
THArgCheck(THCTensor_(nElement)(state, self) == THTensor_(nElement)(src), 2, "sizes do not match");
Expand All @@ -22,9 +21,7 @@ void THCTensor_(copyCPU)(THCState *state, THCTensor *self, struct THTensor *src)
THCTensor_(freeCopyTo)(state, selfc, self);
}
}
#endif

#ifndef THC_REAL_IS_HALF
#define IMPLEMENT_TH_CUDA_TENSOR_COPY(TYPEC) \
void THCTensor_(copy##TYPEC)(THCState *state, THCTensor *self, struct TH##TYPEC##Tensor *src) \
{ \
Expand All @@ -42,19 +39,6 @@ void THCTensor_(copy##TYPEC)(THCState *state, THCTensor *self, struct TH##TYPEC#
THTensor_(free)(srcf); \
} \
}
#else
#define IMPLEMENT_TH_CUDA_TENSOR_COPY(TYPEC) \
void THCTensor_(copy##TYPEC)(THCState *state, THCTensor *self, struct TH##TYPEC##Tensor *src) \
{ \
THArgCheck(THCTensor_(nElement)(state, self) == TH##TYPEC##Tensor_nElement(src), 2, "sizes do not match"); \
THLongStorage *size = TH##TYPEC##Tensor_newSizeOf(src); \
THCudaTensor *buffer = THCudaTensor_newWithSize(state, size, NULL); \
THCudaTensor_copy##TYPEC(state, buffer, src); \
THCudaHalfTensor_copyCudaFloat(state, self, buffer); \
THCudaTensor_free(state, buffer); \
THLongStorage_free(size); \
}
#endif

IMPLEMENT_TH_CUDA_TENSOR_COPY(Byte)
IMPLEMENT_TH_CUDA_TENSOR_COPY(Char)
Expand All @@ -63,10 +47,10 @@ IMPLEMENT_TH_CUDA_TENSOR_COPY(Int)
IMPLEMENT_TH_CUDA_TENSOR_COPY(Long)
IMPLEMENT_TH_CUDA_TENSOR_COPY(Float)
IMPLEMENT_TH_CUDA_TENSOR_COPY(Double)
IMPLEMENT_TH_CUDA_TENSOR_COPY(Half)

/* copyCuda */

#ifndef THC_REAL_IS_HALF
void THTensor_(copyCuda)(THCState *state, THTensor *self, struct THCTensor *src)
{
THArgCheck(THTensor_(nElement)(self) == THCTensor_(nElement)(state, src), 2, "sizes do not match");
Expand All @@ -84,9 +68,7 @@ void THTensor_(copyCuda)(THCState *state, THTensor *self, struct THCTensor *src)
THTensor_(freeCopyTo)(selfc, self);
}
}
#endif

#ifndef THC_REAL_IS_HALF
#define IMPLEMENT_TH_CUDA_TENSOR_COPY_TO(TYPEC) \
void TH_CONCAT_4(TH,TYPEC,Tensor_copyCuda,Real)(THCState *state, TH##TYPEC##Tensor *self, struct THCTensor *src) \
{ \
Expand All @@ -104,19 +86,6 @@ void THTensor_(copyCuda)(THCState *state, THTensor *self, struct THCTensor *src)
THTensor_(free)(srcf); \
} \
}
#else
#define IMPLEMENT_TH_CUDA_TENSOR_COPY_TO(TYPEC) \
void TH_CONCAT_4(TH,TYPEC,Tensor_copyCuda,Real)(THCState *state, TH##TYPEC##Tensor *self, struct THCTensor *src) \
{ \
THArgCheck(TH##TYPEC##Tensor_nElement(self) == THCTensor_(nElement)(state, src), 2, "sizes do not match"); \
THLongStorage *size = THCTensor_(newSizeOf)(state, src); \
THCudaTensor *buffer = THCudaTensor_newWithSize(state, size, NULL); \
THCudaTensor_copyCudaHalf(state, buffer, src); \
TH_CONCAT_3(TH,TYPEC,Tensor_copyCudaFloat)(state, self, buffer); \
THCudaTensor_free(state, buffer); \
THLongStorage_free(size); \
}
#endif

IMPLEMENT_TH_CUDA_TENSOR_COPY_TO(Byte)
IMPLEMENT_TH_CUDA_TENSOR_COPY_TO(Char)
Expand All @@ -125,13 +94,13 @@ IMPLEMENT_TH_CUDA_TENSOR_COPY_TO(Int)
IMPLEMENT_TH_CUDA_TENSOR_COPY_TO(Long)
IMPLEMENT_TH_CUDA_TENSOR_COPY_TO(Float)
IMPLEMENT_TH_CUDA_TENSOR_COPY_TO(Double)
IMPLEMENT_TH_CUDA_TENSOR_COPY_TO(Half)

void THCTensor_(copyCuda)(THCState *state, THCTensor *self, THCTensor *src)
{
THCTensor_(copy)(state, self, src);
}

#ifndef THC_REAL_IS_HALF
void THCTensor_(copyAsyncCPU)(THCState *state, THCTensor *self, struct THTensor *src)
{
THArgCheck(THCTensor_(nElement)(state, self) == THTensor_(nElement)(src), 2, "sizes do not match");
Expand Down Expand Up @@ -193,7 +162,6 @@ void THTensor_(copyAsyncCuda)(THCState *state, THTensor *self, struct THCTensor
THCudaCheck(cudaSetDevice(currentDevice));
}
}
#endif

#undef IMPLEMENT_TH_CUDA_TENSOR_COPY
#undef IMPLEMENT_TH_CUDA_TENSOR_COPY_TO
Expand Down
Loading

0 comments on commit d42eadf

Please sign in to comment.