123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112 |
- #!/usr/bin/env python
- # Copyright Jim Bosch & Ankit Daftery 2010-2012.
- # Distributed under the Boost Software License, Version 1.0.
- # (See accompanying file LICENSE_1_0.txt or copy at
- # http://www.boost.org/LICENSE_1_0.txt)
- import ndarray_ext
- import unittest
- import numpy
class TestNdarray(unittest.TestCase):
    """Exercise the ndarray helpers exposed by the ``ndarray_ext`` module.

    Each test calls an ``ndarray_ext`` function and compares the result
    (shape, contents or raised exception) with what plain ``numpy`` gives
    for the same input.
    """

    # Element types every dtype-parameterised test is repeated for.
    _DTYPES = (numpy.int16, numpy.int32, numpy.float32, numpy.complex128)

    # Shapes holding exactly 60 elements, from 1-D up to 4-D.
    _SHAPES = ((60,), (6, 10), (4, 3, 5), (2, 2, 3, 5))

    # Same as ``_SHAPES`` minus the 1-D case; used by the transpose test.
    _MULTI_DIM_SHAPES = ((6, 10), (4, 3, 5), (2, 2, 3, 5))

    def testNdzeros(self):
        """``zeros`` yields the requested shape and compares equal to numpy zeros."""
        for dtp in self._DTYPES:
            v = numpy.zeros(60, dtype=dtp)
            dt = numpy.dtype(dtp)
            for shape in self._SHAPES:
                a1 = ndarray_ext.zeros(shape, dt)
                a2 = v.reshape(a1.shape)
                self.assertEqual(shape, a1.shape)
                self.assertTrue((a1 == a2).all())

    def testNdzeros_matrix(self):
        """``zeros_matrix`` matches ``numpy.matrix`` zeros in shape, value and type."""
        shape = (6, 10)
        for dtp in self._DTYPES:
            dt = numpy.dtype(dtp)
            a1 = ndarray_ext.zeros_matrix(shape, dt)
            a2 = numpy.matrix(numpy.zeros(shape, dtype=dtp))
            self.assertEqual(shape, a1.shape)
            self.assertTrue((a1 == a2).all())
            self.assertEqual(type(a1), type(a2))

    def testNdarray(self):
        """``array`` (with and without a dtype) matches ``numpy.array`` and reshapes."""
        # Materialise the range so the extension receives a real list on
        # Python 3 as well (on Python 2 ``range`` already returned a list).
        a = list(range(0, 60))
        for dtp in self._DTYPES:
            v = numpy.array(a, dtype=dtp)
            dt = numpy.dtype(dtp)
            a1 = ndarray_ext.array(a)
            a2 = ndarray_ext.array(a, dt)
            self.assertTrue((a1 == v).all())
            self.assertTrue((a2 == v).all())
            for shape in self._SHAPES:
                a1 = a1.reshape(shape)
                self.assertEqual(shape, a1.shape)
                a2 = a2.reshape(shape)
                self.assertEqual(shape, a2.shape)

    def testNdempty(self):
        """``empty`` and ``c_empty`` both return arrays of the requested shape."""
        for dtp in self._DTYPES:
            dt = numpy.dtype(dtp)
            for shape in self._SHAPES:
                a1 = ndarray_ext.empty(shape, dt)
                a2 = ndarray_ext.c_empty(shape, dt)
                self.assertEqual(shape, a1.shape)
                self.assertEqual(shape, a2.shape)

    def testTranspose(self):
        """``transpose`` produces the same shape as ``numpy.ndarray.transpose``."""
        for dtp in self._DTYPES:
            dt = numpy.dtype(dtp)
            for shape in self._MULTI_DIM_SHAPES:
                a1 = numpy.empty(shape, dt)
                a2 = a1.transpose()
                a1 = ndarray_ext.transpose(a1)
                self.assertEqual(a1.shape, a2.shape)

    def testSqueeze(self):
        """``squeeze`` produces the same shape as ``numpy.ndarray.squeeze``."""
        a1 = numpy.array([[[3, 4, 5]]])
        a2 = a1.squeeze()
        a1 = ndarray_ext.squeeze(a1)
        self.assertEqual(a1.shape, a2.shape)

    def testReshape(self):
        """``reshape`` returns an array with the requested shape."""
        a1 = numpy.empty((2, 2))
        a2 = ndarray_ext.reshape(a1, (1, 4))
        self.assertEqual(a2.shape, (1, 4))

    def testShapeIndex(self):
        """``shape_index`` agrees with ``a.shape[i]`` and rejects index 4."""
        a = numpy.arange(24)
        a.shape = (1, 2, 3, 4)

        def shape_check(i):
            self.assertEqual(
                ndarray_ext.shape_index(a, i),
                a.shape[i],
                "shape mismatch at index %d" % i,
            )

        # Non-negative indices first, then the negative ones -1 .. -4.
        for i in range(4):
            shape_check(i)
        for i in range(-1, -5, -1):
            shape_check(i)
        # Index 4 is out of bounds for a 4-D array -- should raise IndexError.
        self.assertRaises(IndexError, ndarray_ext.shape_index, a, 4)

    def testStridesIndex(self):
        """``strides_index`` agrees with ``a.strides[i]`` and rejects index 4."""
        a = numpy.arange(24)
        a.shape = (1, 2, 3, 4)

        def strides_check(i):
            self.assertEqual(
                ndarray_ext.strides_index(a, i),
                a.strides[i],
                "stride mismatch at index %d" % i,
            )

        # Non-negative indices first, then the negative ones -1 .. -4.
        for i in range(4):
            strides_check(i)
        for i in range(-1, -5, -1):
            strides_check(i)
        # Index 4 is out of bounds for a 4-D array -- should raise IndexError.
        self.assertRaises(IndexError, ndarray_ext.strides_index, a, 4)
# Run the test suite only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    unittest.main()
|