From a18c4383520badb0b27df1217fae5017bb3839b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vladim=C3=ADr=20Vondru=C5=A1?= Date: Thu, 25 Jul 2019 16:57:40 +0200 Subject: [PATCH] python: so we need MutableStridedArrayView4D as well. In order to access pixel data in 3D images, of course. --- src/python/corrade/containers.cpp | 99 ++++++++++++++++++++++ src/python/corrade/test/test_containers.py | 78 +++++++++++++++++ 2 files changed, 177 insertions(+) diff --git a/src/python/corrade/containers.cpp b/src/python/corrade/containers.cpp index aa2fa78..4967c64 100644 --- a/src/python/corrade/containers.cpp +++ b/src/python/corrade/containers.cpp @@ -171,6 +171,7 @@ template struct DimensionsTuple; template struct DimensionsTuple<1, T> { typedef std::tuple Type; }; template struct DimensionsTuple<2, T> { typedef std::tuple Type; }; template struct DimensionsTuple<3, T> { typedef std::tuple Type; }; +template struct DimensionsTuple<4, T> { typedef std::tuple Type; }; /* Size tuple for given dimension */ template typename DimensionsTuple::Type size(Containers::StridedDimensions); @@ -183,6 +184,9 @@ template<> std::tuple size(Containers::StridedDimensio template<> std::tuple size(Containers::StridedDimensions<3, std::size_t> size) { return std::make_tuple(size[0], size[1], size[2]); } +template<> std::tuple size(Containers::StridedDimensions<4, std::size_t> size) { + return std::make_tuple(size[0], size[1], size[2], size[3]); +} /* Stride tuple for given dimension */ template typename DimensionsTuple::Type stride(Containers::StridedDimensions); @@ -195,6 +199,9 @@ template<> std::tuple stride(Containers::Strided template<> std::tuple stride(Containers::StridedDimensions<3, std::ptrdiff_t> stride) { return std::make_tuple(stride[0], stride[1], stride[2]); } +template<> std::tuple stride(Containers::StridedDimensions<4, std::ptrdiff_t> stride) { + return std::make_tuple(stride[0], stride[1], stride[2], stride[3]); +} /* Byte conversion for given dimension */ template Containers::Array bytes(Containers::StridedArrayView); @@ -219,6 +226,15 @@ template<> Containers::Array bytes(Containers::StridedArrayView3D Containers::Array bytes(Containers::StridedArrayView<4, const char> view) { + Containers::Array out{view.size()[0]*view.size()[1]*view.size()[2]*view.size()[3]}; + std::size_t pos = 0; + for(Containers::StridedArrayView3D i: view) + for(Containers::StridedArrayView2D j: i) + for(Containers::StridedArrayView1D k: j) + for(const char l: k) out[pos++] = l; + return out; +} /* Getting a runtime tuple index. Ugh. */ template const T& dimensionsTupleGet(const typename DimensionsTuple<1, T>::Type& tuple, std::size_t i) { @@ -236,6 +252,13 @@ template const T& dimensionsTupleGet(const typename DimensionsTuple<3, if(i == 2) return std::get<2>(tuple); CORRADE_ASSERT_UNREACHABLE(); /* LCOV_EXCL_LINE */ } +template const T& dimensionsTupleGet(const typename DimensionsTuple<4, T>::Type& tuple, std::size_t i) { + if(i == 0) return std::get<0>(tuple); + if(i == 1) return std::get<1>(tuple); + if(i == 2) return std::get<2>(tuple); + if(i == 3) return std::get<3>(tuple); + CORRADE_ASSERT_UNREACHABLE(); /* LCOV_EXCL_LINE */ +} template bool stridedArrayViewBufferProtocol(T& self, Py_buffer& buffer, int flags) { if((flags & PyBUF_STRIDES) != PyBUF_STRIDES) { @@ -445,6 +468,60 @@ template void stridedArrayView3D(py::class_>& }, "Broadcast a dimension"); } +template void stridedArrayView4D(py::class_>& c) { + c + .def("__getitem__", [](const PyStridedArrayView<4, T>& self, const std::tuple& i) { + if(std::get<0>(i) >= self.size()[0] || + std::get<1>(i) >= self.size()[1] || + std::get<2>(i) >= self.size()[2] || + std::get<3>(i) >= self.size()[3]) throw pybind11::index_error{}; + return self[std::get<0>(i)][std::get<1>(i)][std::get<2>(i)][std::get<3>(i)]; + }, "Value at given position") + .def("transposed", [](const PyStridedArrayView<4, T>& self, const std::size_t a, std::size_t b) { + if((a == 0 && b == 1) || + (a == 1 && b == 0)) + return PyStridedArrayView<4, T>{self.template transposed<0, 1>(), self.obj}; + if((a == 0 && b == 2) || + (a == 2 && b == 0)) + return PyStridedArrayView<4, T>{self.template transposed<0, 2>(), self.obj}; + if((a == 0 && b == 3) || + (a == 3 && b == 0)) + return PyStridedArrayView<4, T>{self.template transposed<0, 3>(), self.obj}; + if((a == 1 && b == 2) || + (a == 2 && b == 1)) + return PyStridedArrayView<4, T>{self.template transposed<1, 2>(), self.obj}; + if((a == 1 && b == 3) || + (a == 3 && b == 1)) + return PyStridedArrayView<4, T>{self.template transposed<1, 3>(), self.obj}; + if((a == 2 && b == 3) || + (a == 3 && b == 2)) + return PyStridedArrayView<4, T>{self.template transposed<2, 3>(), self.obj}; + throw py::value_error{Utility::formatString("dimensions {}, {} can't be transposed in a {}D view", a, b, 4)}; + }, "Transpose two dimensions") + .def("flipped", [](const PyStridedArrayView<4, T>& self, const std::size_t dimension) { + if(dimension == 0) + return PyStridedArrayView<4, T>{self.template flipped<0>(), self.obj}; + if(dimension == 1) + return PyStridedArrayView<4, T>{self.template flipped<1>(), self.obj}; + if(dimension == 2) + return PyStridedArrayView<4, T>{self.template flipped<2>(), self.obj}; + if(dimension == 3) + return PyStridedArrayView<4, T>{self.template flipped<3>(), self.obj}; + throw py::value_error{Utility::formatString("dimension {} out of range for a {}D view", dimension, 4)}; + }, "Flip a dimension") + .def("broadcasted", [](const PyStridedArrayView<4, T>& self, const std::size_t dimension, std::size_t size) { + if(dimension == 0) + return PyStridedArrayView<4, T>{self.template broadcasted<0>(size), self.obj}; + if(dimension == 1) + return PyStridedArrayView<4, T>{self.template broadcasted<1>(size), self.obj}; + if(dimension == 2) + return PyStridedArrayView<4, T>{self.template broadcasted<2>(size), self.obj}; + if(dimension == 3) + return PyStridedArrayView<4, T>{self.template broadcasted<3>(size), self.obj}; + throw py::value_error{Utility::formatString("dimension {} out of range for a {}D view", dimension, 4)}; + }, "Broadcast a dimension"); +} + template void mutableStridedArrayView1D(py::class_>& c) { c .def("__setitem__", [](const PyStridedArrayView<1, T>& self, const std::size_t i, const T& value) { @@ -472,6 +549,17 @@ template void mutableStridedArrayView3D(py::class_ void mutableStridedArrayView4D(py::class_>& c) { + c + .def("__setitem__", [](const PyStridedArrayView<4, T>& self, const std::tuple& i, const T& value) { + if(std::get<0>(i) >= self.size()[0] || + std::get<1>(i) >= self.size()[1] || + std::get<2>(i) >= self.size()[2] || + std::get<3>(i) >= self.size()[3]) throw pybind11::index_error{}; + self[std::get<0>(i)][std::get<1>(i)][std::get<2>(i)][std::get<3>(i)] = value; + }, "Set a value at given position"); +} + } void containers(py::module& m) { @@ -492,6 +580,8 @@ void containers(py::module& m) { "StridedArrayView2D", "Two-dimensional array view with stride information", py::buffer_protocol{}}; py::class_> stridedArrayView3D_{m, "StridedArrayView3D", "Three-dimensional array view with stride information", py::buffer_protocol{}}; + py::class_> stridedArrayView4D_{m, + "StridedArrayView4D", "Four-dimensional array view with stride information", py::buffer_protocol{}}; stridedArrayView(stridedArrayView1D_); stridedArrayView1D(stridedArrayView1D_); stridedArrayView(stridedArrayView2D_); @@ -500,6 +590,9 @@ void containers(py::module& m) { stridedArrayView(stridedArrayView3D_); stridedArrayViewND(stridedArrayView3D_); stridedArrayView3D(stridedArrayView3D_); + stridedArrayView(stridedArrayView4D_); + stridedArrayViewND(stridedArrayView4D_); + stridedArrayView4D(stridedArrayView4D_); py::class_> mutableStridedArrayView1D_{m, "MutableStridedArrayView1D", "Mutable one-dimensional array view with stride information", py::buffer_protocol{}}; @@ -507,6 +600,8 @@ void containers(py::module& m) { "MutableStridedArrayView2D", "Mutable two-dimensional array view with stride information", py::buffer_protocol{}}; py::class_> mutableStridedArrayView3D_{m, "MutableStridedArrayView3D", "Mutable three-dimensional array view with stride information", py::buffer_protocol{}}; + py::class_> mutableStridedArrayView4D_{m, + "MutableStridedArrayView4D", "Mutable four-dimensional array view with stride information", py::buffer_protocol{}}; stridedArrayView(mutableStridedArrayView1D_); stridedArrayView1D(mutableStridedArrayView1D_); stridedArrayView(mutableStridedArrayView2D_); @@ -515,9 +610,13 @@ void containers(py::module& m) { stridedArrayView(mutableStridedArrayView3D_); stridedArrayViewND(mutableStridedArrayView3D_); stridedArrayView3D(mutableStridedArrayView3D_); + stridedArrayView(mutableStridedArrayView4D_); + stridedArrayViewND(mutableStridedArrayView4D_); + stridedArrayView4D(mutableStridedArrayView4D_); mutableStridedArrayView1D(mutableStridedArrayView1D_); mutableStridedArrayView2D(mutableStridedArrayView2D_); mutableStridedArrayView3D(mutableStridedArrayView3D_); + mutableStridedArrayView4D(mutableStridedArrayView4D_); } } diff --git a/src/python/corrade/test/test_containers.py b/src/python/corrade/test/test_containers.py index a49d19e..ee5697b 100644 --- a/src/python/corrade/test/test_containers.py +++ b/src/python/corrade/test/test_containers.py @@ -717,3 +717,81 @@ class StridedArrayView3D(unittest.TestCase): self.assertEqual(f.size, (2, 3, 5)) self.assertEqual(f.stride, (24, 8, 0)) self.assertEqual(bytes(f), b'000004444488888ccccc0000044444') + +# This is just a dumb copy of the above with one dimension inserted at the +# second place. +class StridedArrayView4D(unittest.TestCase): + def test_init_buffer(self): + a = (b'01234567' + b'456789ab' + b'89abcdef' + + b'cdef0123' + b'01234567' + b'456789ab') + b = containers.StridedArrayView4D(memoryview(a).cast('b', shape=[2, 1, 3, 8])) + self.assertEqual(len(b), 2) + self.assertEqual(bytes(b), b'01234567456789ab89abcdefcdef012301234567456789ab') + self.assertEqual(b.size, (2, 1, 3, 8)) + self.assertEqual(b.stride, (24, 24, 8, 1)) + self.assertEqual(b[1, 0, 2, 3], '7') + self.assertEqual(b[1][0][2][3], '7') + + def test_init_buffer_mutable(self): + a = bytearray(b'01234567' + b'456789ab' + b'89abcdef' + + b'cdef0123' + b'01234567' + b'456789ab') + b = containers.MutableStridedArrayView4D(memoryview(a).cast('b', shape=[2, 1, 3, 8])) + b[0, 0, 0, 7] = '!' + b[0, 0, 1, 7] = '!' + b[0, 0, 2, 7] = '!' + b[1, 0, 0, 7] = '!' + b[1, 0, 1, 7] = '!' + b[1, 0, 2, 7] = '!' + self.assertEqual(b[1][0][1][7], '!') + self.assertEqual(bytes(b), b'0123456!' + b'456789a!' + b'89abcde!' + + b'cdef012!' + b'0123456!' + b'456789a!') + + def test_ops(self): + a = (b'01234567' + b'456789ab' + b'89abcdef' + + b'cdef0123' + b'01234567' + b'456789ab') + v = memoryview(a).cast('b', shape=[2, 1, 3, 8]) + + b = containers.StridedArrayView4D(v).transposed(0, 2).flipped(0) + self.assertEqual(b.size, (3, 1, 2, 8)) + self.assertEqual(b.stride, (-8, 24, 24, 1)) + self.assertEqual(bytes(b), b'89abcdef456789ab456789ab0123456701234567cdef0123') + + c = containers.StridedArrayView4D(v).transposed(3, 0).flipped(2) + self.assertEqual(c.size, (8, 1, 3, 2)) + self.assertEqual(c.stride, (1, 24, -8, 24)) + self.assertEqual(bytes(c), b'84400c95511da6622eb7733fc88440d99551eaa662fbb773') + + d = containers.StridedArrayView4D(v).transposed(2, 3)[0:1, :, 3:5, :].broadcasted(0, 5) + self.assertEqual(d.size, (5, 1, 2, 3)) + self.assertEqual(d.stride, (0, 24, 1, 8)) + self.assertEqual(bytes(d), b'37b48c37b48c37b48c37b48c37b48c') + + e = containers.StridedArrayView4D(v)[:, :, 1:2, 3:4].flipped(3).broadcasted(2, 2) + self.assertEqual(e.size, (2, 1, 2, 1)) + self.assertEqual(e.stride, (24, 24, 0, -1)) + self.assertEqual(bytes(e), b'7733') + + f = containers.StridedArrayView4D(v)[:, :, :, 0:1].broadcasted(3, 5) + self.assertEqual(f.size, (2, 1, 3, 5)) + self.assertEqual(f.stride, (24, 24, 8, 0)) + self.assertEqual(bytes(f), b'000004444488888ccccc0000044444')