// NOTE(review): this text is a doxygen source-listing extraction of the
// header. The leading numbers are the original file's line numbers, and the
// lines between them (blank lines, braces, comments and some statements) are
// missing, so it does not compile as shown.
//
// getArrayTypeObject(): return the Python type object used to create arrays.
// It imports the "vigra" module and returns its "standardArrayType"
// attribute. numpy's PyArray_Type is passed as pythonGetAttr's third
// argument, presumably the fallback when the attribute is unavailable
// (confirm in python_utility.hxx).
36 #ifndef VIGRA_NUMPY_ARRAY_TAGGEDSHAPE_HXX 37 #define VIGRA_NUMPY_ARRAY_TAGGEDSHAPE_HXX 39 #ifndef NPY_NO_DEPRECATED_API 40 # define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION 44 #include "array_vector.hxx" 45 #include "python_utility.hxx" 46 #include "axistags.hxx" 53 python_ptr getArrayTypeObject()
55 python_ptr arraytype((PyObject*)&PyArray_Type);
56 python_ptr
// NOTE(review): every other new reference in this file is wrapped with
// python_ptr::keep_count, but the new reference returned by
// PyImport_ImportModule is not. If python_ptr's default policy increments
// the count, the module reference leaks; confirm.
vigra(PyImport_ImportModule(
"vigra"));
// Original lines 57-58 are missing here; presumably they handle a failed
// import before the attribute lookup.
59 return pythonGetAttr(
vigra,
"standardArrayType", arraytype);
// defaultOrder(): return the "defaultOrder" attribute of the array type
// object obtained from getArrayTypeObject(). defaultValue ("C" unless the
// caller overrides it) is passed as pythonGetAttr's fallback argument.
63 std::string defaultOrder(std::string defaultValue =
"C")
65 python_ptr arraytype = getArrayTypeObject();
66 return pythonGetAttr(arraytype,
"defaultOrder", defaultValue);
// defaultAxistags(): build default axistags by calling the array type
// object's Python method defaultAxistags(ndim, order).
//   ndim  - number of axes, passed as a Python int.
//   order - axis-order string. It is replaced by defaultOrder() under a
//           condition on original line 72, which is not shown here
//           (presumably an empty-string test).
// All temporaries are new references, held with keep_count. Error checking
// and the return statement (original lines 80-86) are missing here.
70 python_ptr defaultAxistags(
int ndim, std::string order =
"")
73 order = defaultOrder();
74 python_ptr arraytype = getArrayTypeObject();
75 python_ptr func(PyString_FromString(
"defaultAxistags"), python_ptr::keep_count);
76 python_ptr d(PyInt_FromLong(ndim), python_ptr::keep_count);
77 python_ptr o(PyString_FromString(order.c_str()), python_ptr::keep_count);
78 python_ptr axistags(PyObject_CallMethodObjArgs(arraytype, func.get(), d.get(), o.get(), NULL),
79 python_ptr::keep_count);
// emptyAxistags(): build axistags for ndim axes by calling the array type
// object's Python method _empty_axistags(ndim). Error checking and the
// return statement (original lines 94-100) are missing here.
87 python_ptr emptyAxistags(
int ndim)
89 python_ptr arraytype = getArrayTypeObject();
90 python_ptr func(PyString_FromString(
"_empty_axistags"), python_ptr::keep_count);
91 python_ptr d(PyInt_FromLong(ndim), python_ptr::keep_count);
92 python_ptr axistags(PyObject_CallMethodObjArgs(arraytype, func.get(), d.get(), NULL),
93 python_ptr::keep_count);
// getAxisPermutationImpl(): call the Python method `name` on `object` with
// the axis-type filter `type` (passed as a Python int), and convert the
// returned sequence of ints into an ArrayVector<npy_intp>.
//   permute      - output. The transfer from `res` into it happens on
//                  original lines 139-143, which are missing here
//                  (presumably a swap or assignment; confirm).
//   ignoreErrors - if the Python call fails and this is true, the branch at
//                  original line 110 is taken. Its body (lines 111-114) is
//                  not shown; presumably it clears the error and returns.
// Otherwise a failed call is turned into a C++ exception by
// pythonToCppException(). A non-sequence result raises ValueError
// "<name>() did not return a sequence.".
102 getAxisPermutationImpl(ArrayVector<npy_intp> & permute,
103 python_ptr
object,
const char * name,
104 AxisInfo::AxisType type,
bool ignoreErrors)
106 python_ptr func(PyString_FromString(name), python_ptr::keep_count);
107 python_ptr t(PyInt_FromLong((
long)type), python_ptr::keep_count);
108 python_ptr permutation(PyObject_CallMethodObjArgs(
object, func.get(), t.get(), NULL),
109 python_ptr::keep_count);
110 if(!permutation && ignoreErrors)
115 pythonToCppException(permutation);
117 if(!PySequence_Check(permutation))
121 std::string message = std::string(name) +
"() did not return a sequence.";
122 PyErr_SetString(PyExc_ValueError, message.c_str());
123 pythonToCppException(
false);
// NOTE(review): PySequence_Length() returns -1 on failure, for example for
// an object that passes PySequence_Check but has no __len__. That value goes
// straight into the ArrayVector size constructor here; consider checking it.
126 ArrayVector<npy_intp> res(PySequence_Length(permutation));
127 for(
int k=0; k<(int)res.size(); ++k)
129 python_ptr i(PySequence_GetItem(permutation, k), python_ptr::keep_count);
// Reached when element k fails the check on original lines 130-133, which
// are missing here (presumably an integer-type check).
134 std::string message = std::string(name) +
"() did not return a sequence of int.";
135 PyErr_SetString(PyExc_ValueError, message.c_str());
136 pythonToCppException(
false);
138 res[k] = PyInt_AsLong(i);
// getAxisPermutationImpl() without an axis-type filter: forwards to the
// overload above with AxisInfo::AllAxes, so the permutation covers every
// axis.
145 getAxisPermutationImpl(ArrayVector<npy_intp> & permute,
146 python_ptr
object,
const char * name,
bool ignoreErrors)
148 getAxisPermutationImpl(permute,
object, name, AxisInfo::AllAxes, ignoreErrors);
// PyAxisTags: C++ wrapper around a Python AxisTags object held in the
// python_ptr member `axistags`. The class header and the member declaration
// are on lines missing here. `pointer` is the raw PyObject* type.
167 typedef PyObject * pointer;
// PyAxisTags(tags, createCopy): wrap an existing Python object.
// A `tags` that is not a sequence raises TypeError through
// pythonToCppException(false). A zero-length sequence takes the branch at
// original line 182, whose body is not shown.
// The "__copy__" call is presumably made only when createCopy is true; the
// guarding condition is on a missing line, so confirm.
171 PyAxisTags(python_ptr tags = python_ptr(),
bool createCopy =
false)
176 if(!PySequence_Check(tags))
178 PyErr_SetString(PyExc_TypeError,
179 "PyAxisTags(tags): tags argument must have type 'AxisTags'.");
180 pythonToCppException(
false);
182 else if(PySequence_Length(tags) == 0)
189 python_ptr func(PyString_FromString(
"__copy__"), python_ptr::keep_count);
190 axistags = python_ptr(PyObject_CallMethodObjArgs(tags, func.get(), NULL),
191 python_ptr::keep_count);
// NOTE(review): no error check on the __copy__ result is visible. Original
// lines 192-198 are missing, so confirm that a failed copy is reported.
// Copy constructor: either calls __copy__ on other's Python object (line 206)
// or shares other.axistags (line 211). The condition choosing between them
// is on a missing line (presumably createCopy).
199 PyAxisTags(PyAxisTags
const & other,
bool createCopy =
false)
205 python_ptr func(PyString_FromString(
"__copy__"), python_ptr::keep_count);
206 axistags = python_ptr(PyObject_CallMethodObjArgs(other.axistags, func.get(), NULL),
207 python_ptr::keep_count);
211 axistags = other.axistags;
// PyAxisTags(ndim, order): create fresh axistags through
// detail::defaultAxistags or detail::emptyAxistags. The selecting condition
// (original lines 216-219) is missing here; presumably it tests for a
// non-empty `order`.
215 PyAxisTags(
int ndim, std::string
const & order =
"")
218 axistags = detail::defaultAxistags(ndim, order);
220 axistags = detail::emptyAxistags(ndim);
// size(): number of axes. Only one branch of its conditional is visible: it
// returns PySequence_Length(axistags). The other branch is on a missing line.
226 ? PySequence_Length(axistags)
// channelIndex(defaultVal): value of the Python "channelIndex" attribute,
// with defaultVal passed as pythonGetAttr's fallback argument.
230 long channelIndex(
long defaultVal)
const 232 return pythonGetAttr(axistags,
"channelIndex", defaultVal);
// channelIndex(): as above, using size() as the fallback. hasChannelAxis()
// is therefore true exactly when the reported index differs from size().
235 long channelIndex()
const 237 return channelIndex(size());
240 bool hasChannelAxis()
const 242 return channelIndex() != size();
// innerNonchannelIndex(...): the same pair of accessors for the Python
// "innerNonchannelIndex" attribute.
245 long innerNonchannelIndex(
long defaultVal)
const 247 return pythonGetAttr(axistags,
"innerNonchannelIndex", defaultVal);
250 long innerNonchannelIndex()
const 252 return innerNonchannelIndex(size());
// setChannelDescription(): call the Python AxisTags method
// setChannelDescription(description). A Python error is turned into a C++
// exception via pythonToCppException(res). Original lines 256-258 are
// missing here (presumably a guard for empty axistags; confirm).
255 void setChannelDescription(std::string
const & description)
259 python_ptr d(PyString_FromString(description.c_str()), python_ptr::keep_count);
260 python_ptr func(PyString_FromString(
"setChannelDescription"), python_ptr::keep_count);
261 python_ptr res(PyObject_CallMethodObjArgs(axistags, func.get(), d.get(), NULL),
262 python_ptr::keep_count);
263 pythonToCppException(res);
// resolution(index): call the Python method resolution(index) and return the
// result as a double. A Python error becomes a C++ exception. A non-float
// result raises TypeError "AxisTags.resolution() did not return float.".
// NOTE(review): PyFloat_Check() is false for Python ints, so an integer
// resolution also triggers this TypeError.
// Original lines 267-269 are missing here (presumably a guard for empty
// axistags; confirm).
266 double resolution(
long index)
270 python_ptr func(PyString_FromString(
"resolution"), python_ptr::keep_count);
271 python_ptr i(PyInt_FromLong(index), python_ptr::keep_count);
272 python_ptr res(PyObject_CallMethodObjArgs(axistags, func.get(), i.get(), NULL),
273 python_ptr::keep_count);
274 pythonToCppException(res);
275 if(!PyFloat_Check(res))
277 PyErr_SetString(PyExc_TypeError,
"AxisTags.resolution() did not return float.");
278 pythonToCppException(
false);
280 return PyFloat_AsDouble(res);
// setResolution(index, resolution): call the Python method
// setResolution(index, resolution). Errors become C++ exceptions.
283 void setResolution(
long index,
double resolution)
287 python_ptr func(PyString_FromString(
"setResolution"), python_ptr::keep_count);
288 python_ptr i(PyInt_FromLong(index), python_ptr::keep_count);
289 python_ptr r(PyFloat_FromDouble(resolution), python_ptr::keep_count);
290 python_ptr res(PyObject_CallMethodObjArgs(axistags, func.get(), i.get(), r.get(), NULL),
291 python_ptr::keep_count);
292 pythonToCppException(res);
// scaleResolution(index, factor): call the Python method
// scaleResolution(index, factor). Errors become C++ exceptions.
295 void scaleResolution(
long index,
double factor)
299 python_ptr func(PyString_FromString(
"scaleResolution"), python_ptr::keep_count);
300 python_ptr i(PyInt_FromLong(index), python_ptr::keep_count);
301 python_ptr f(PyFloat_FromDouble(factor), python_ptr::keep_count);
302 python_ptr res(PyObject_CallMethodObjArgs(axistags, func.get(), i.get(), f.get(), NULL),
303 python_ptr::keep_count);
304 pythonToCppException(res);
// toFrequencyDomain(index, size, sign): call the Python method
// toFrequencyDomain(index, size) when sign == 1. Any other sign calls
// fromFrequencyDomain(index, size). Errors become C++ exceptions.
307 void toFrequencyDomain(
long index,
int size,
int sign = 1)
311 python_ptr func(
sign == 1
312 ? PyString_FromString(
"toFrequencyDomain")
313 : PyString_FromString(
"fromFrequencyDomain"),
314 python_ptr::keep_count);
315 python_ptr i(PyInt_FromLong(index), python_ptr::keep_count);
316 python_ptr s(PyInt_FromLong(size), python_ptr::keep_count);
317 python_ptr res(PyObject_CallMethodObjArgs(axistags, func.get(), i.get(), s.get(), NULL),
318 python_ptr::keep_count);
319 pythonToCppException(res);
// fromFrequencyDomain(index, size): shorthand for
// toFrequencyDomain(index, size, -1).
322 void fromFrequencyDomain(
long index,
int size)
324 toFrequencyDomain(index, size, -1);
// permutationToNormalOrder / permutationFromNormalOrder: fill a local
// `permute` by calling the Python AxisTags method of the same name through
// detail::getAxisPermutationImpl. The overloads taking `types` forward the
// axis-type filter; the others use AxisInfo::AllAxes. ignoreErrors is passed
// through unchanged. The `return permute;` statements are on lines missing
// here (presumably).
327 ArrayVector<npy_intp>
328 permutationToNormalOrder(
bool ignoreErrors =
false)
const 330 ArrayVector<npy_intp> permute;
331 detail::getAxisPermutationImpl(permute, axistags,
"permutationToNormalOrder", ignoreErrors);
335 ArrayVector<npy_intp>
336 permutationToNormalOrder(AxisInfo::AxisType types,
bool ignoreErrors =
false)
const 338 ArrayVector<npy_intp> permute;
339 detail::getAxisPermutationImpl(permute, axistags,
340 "permutationToNormalOrder", types, ignoreErrors);
344 ArrayVector<npy_intp>
345 permutationFromNormalOrder(
bool ignoreErrors =
false)
const 347 ArrayVector<npy_intp> permute;
348 detail::getAxisPermutationImpl(permute, axistags,
349 "permutationFromNormalOrder", ignoreErrors);
353 ArrayVector<npy_intp>
354 permutationFromNormalOrder(AxisInfo::AxisType types,
bool ignoreErrors =
false)
const 356 ArrayVector<npy_intp> permute;
357 detail::getAxisPermutationImpl(permute, axistags,
358 "permutationFromNormalOrder", types, ignoreErrors);
// dropChannelAxis(): call the Python AxisTags method dropChannelAxis() on the
// wrapped object. Errors become C++ exceptions.
362 void dropChannelAxis()
366 python_ptr func(PyString_FromString(
"dropChannelAxis"),
367 python_ptr::keep_count);
368 python_ptr res(PyObject_CallMethodObjArgs(axistags, func.get(), NULL),
369 python_ptr::keep_count);
370 pythonToCppException(res);
// insertChannelAxis(): call the Python AxisTags method insertChannelAxis()
// on the wrapped object. Errors become C++ exceptions.
373 void insertChannelAxis()
377 python_ptr func(PyString_FromString(
"insertChannelAxis"),
378 python_ptr::keep_count);
379 python_ptr res(PyObject_CallMethodObjArgs(axistags, func.get(), NULL),
380 python_ptr::keep_count);
381 pythonToCppException(res);
// Returns the raw PyObject* held by axistags. The enclosing signature
// (original lines 383-385) is missing here.
386 return axistags.get();
// operator!(): its body is on lines missing here (presumably a null test on
// axistags).
389 bool operator!()
// The line below also starts class TaggedShape, whose header is not shown.
// ChannelAxis records where the channel axis sits in `shape`: first, last,
// or none.
const 404 enum ChannelAxis { first, last, none };
// shape: current extents.
// original_shape: the extents recorded at construction; scaleAxisResolution()
// compares it with shape to derive resolution factors.
// The PyAxisTags member `axistags` is declared on original line 407, which is
// missing here.
406 ArrayVector<npy_intp> shape, original_shape;
408 ChannelAxis channelAxis;
409 std::string channelDescription;
// Constructors: initialize both shape and original_shape from `sh`, given as
// a TinyVector or an ArrayVector, with or without axistags. The remaining
// member initializers and the `template <class T>` headers for the
// ArrayVector overloads are on lines missing here.
417 template <
class U,
int N>
418 TaggedShape(TinyVector<U, N>
const & sh, PyAxisTags tags)
419 : shape(sh.begin(), sh.end()),
420 original_shape(sh.begin(), sh.end()),
426 TaggedShape(ArrayVector<T>
const & sh, PyAxisTags tags)
427 : shape(sh.begin(), sh.end()),
428 original_shape(sh.begin(), sh.end()),
433 template <
class U,
int N>
434 explicit TaggedShape(TinyVector<U, N>
const & sh)
435 : shape(sh.begin(), sh.end()),
436 original_shape(sh.begin(), sh.end()),
441 explicit TaggedShape(ArrayVector<T>
const & sh)
442 : shape(sh.begin(), sh.end()),
443 original_shape(sh.begin(), sh.end()),
// resize(sh): copy `sh` into the non-channel entries of `shape`, i.e.
// indices [start, stop). start/stop depend on channelAxis (first / last);
// their ternary arms are on missing lines, presumably skipping index 0 or
// the last index.
// Precondition: N must equal stop - start, unless the shape is still empty.
// The handling of the empty case (original lines 459-462) is not shown.
447 template <
class U,
int N>
448 TaggedShape & resize(TinyVector<U, N>
const & sh)
450 int start = channelAxis == first
453 stop = channelAxis == last
457 vigra_precondition(N == stop - start || size() == 0,
458 "TaggedShape.resize(): size mismatch.");
463 for(
int k=0; k<N; ++k)
464 shape[k+start] = sh[k];
// Convenience overloads for 1-4 extents: each wraps its arguments into a
// TinyVector<MultiArrayIndex, N> and forwards to resize(sh). Their
// signatures are on lines missing here.
471 return resize(TinyVector<MultiArrayIndex, 1>(v1));
476 return resize(TinyVector<MultiArrayIndex, 2>(v1, v2));
481 return resize(TinyVector<MultiArrayIndex, 3>(v1, v2, v3));
487 return resize(TinyVector<MultiArrayIndex, 4>(v1, v2, v3, v4));
// operator[](i): mutable and const element access. The bodies are on lines
// missing here (presumably indexing into `shape`).
490 npy_intp & operator[](
int i)
495 npy_intp operator[](
int i)
// size(): its body is missing here (presumably shape.size()).
const 500 unsigned int size()
// The next two fragments belong to members whose signatures (around original
// lines 505 and 524) are missing. Both iterate over the non-channel range
// [start, stop) computed from channelAxis, as resize() does.
const 507 int start = channelAxis == first
510 stop = channelAxis == last
513 for(
int k=start; k<stop; ++k)
526 int start = channelAxis == first
529 stop = channelAxis == last
532 for(
int k=start; k<stop; ++k)
// rotateToNormalOrder(): when axistags are present and the channel axis is
// last, rotate both shape and original_shape right by one position, so the
// last entry (the channel count) moves to index 0.
// Original lines 553-555 are missing; presumably they update channelAxis to
// first (confirm).
538 void rotateToNormalOrder()
540 if(axistags && channelAxis == last)
542 int ndim = (int)size();
// The local `channelCount` shadows the member function channelCount().
544 npy_intp channelCount = shape[ndim-1];
545 for(
int k=ndim-1; k>0; --k)
546 shape[k] = shape[k-1];
547 shape[0] = channelCount;
549 channelCount = original_shape[ndim-1];
550 for(
int k=ndim-1; k>0; --k)
551 original_shape[k] = original_shape[k-1];
552 original_shape[0] = channelCount;
// setChannelDescription(): store `description` in channelDescription.
// finalizeTaggedShape() later forwards a non-empty value to the axistags.
558 TaggedShape & setChannelDescription(std::string
const & description)
562 channelDescription = description;
// setChannelIndexLast(): its body is on lines missing here.
566 TaggedShape & setChannelIndexLast()
// transposeShape(p): reorder the non-channel extents by permutation `p`
// (original_shape[k] = shape[p[k]], offset by sstart), then make shape equal
// to the result. The matching per-axis resolutions are moved in the axistags
// through permutationToNormalOrder().
// tstart/sstart skip a channel entry in the tags/shape. Their ternary arms
// are on missing lines.
// Precondition: N equals the number of non-channel tags.
574 template <
class U,
int N>
575 TaggedShape & transposeShape(TinyVector<U, N>
const & p)
577 int ntags = axistags.size();
578 ArrayVector<npy_intp> permute = axistags.permutationToNormalOrder();
580 int tstart = (axistags.channelIndex(ntags) < ntags)
583 int sstart = (channelAxis == first)
586 int ndim = ntags - tstart;
588 vigra_precondition(N == ndim,
589 "TaggedShape.transposeShape(): size mismatch.");
// NOTE(review): this uses PyAxisTags(python_ptr, createCopy = false). If that
// constructor shares the Python object instead of copying it, then
// setResolution() below writes into the same object that
// axistags.resolution() reads from. A non-trivial permutation could then read
// values that were already overwritten. Consider passing createCopy = true;
// confirm.
591 PyAxisTags newAxistags(axistags.axistags);
592 for(
int k=0; k<ndim; ++k)
594 original_shape[k+sstart] = shape[p[k]+sstart];
595 newAxistags.setResolution(permute[k+tstart], axistags.resolution(permute[p[k]+tstart]));
597 shape = original_shape;
598 axistags = newAxistags;
// toFrequencyDomain(sign): for every non-channel axis, call
// PyAxisTags::toFrequencyDomain(tag index, current extent, sign).
// The tag index comes from permutationToNormalOrder(); the extent comes from
// shape[sstart, send). sign == 1 converts to the frequency domain; any other
// value converts back. sstart/send/tstart ternary arms are on missing lines.
603 TaggedShape & toFrequencyDomain(
int sign = 1)
605 int ntags = axistags.size();
607 ArrayVector<npy_intp> permute = axistags.permutationToNormalOrder();
609 int tstart = (axistags.channelIndex(ntags) < ntags)
612 int sstart = (channelAxis == first)
615 int send = (channelAxis == last)
// The local `size` shadows the member function size() inside this loop.
618 int size = send - sstart;
620 for(
int k=0; k<size; ++k)
622 axistags.toFrequencyDomain(permute[k+tstart], shape[k+sstart],
sign);
// fromFrequencyDomain(): shorthand for toFrequencyDomain(-1).
628 TaggedShape & fromFrequencyDomain()
630 return toFrequencyDomain(-1);
// compatible(other): first compares channelCount() of both shapes. It then
// compares the number of non-channel axes, and then each non-channel extent
// pairwise. The `return` statements for the mismatch cases and the final
// result are on lines missing here (presumably false on any mismatch, true
// otherwise).
633 bool compatible(TaggedShape
const & other)
const 635 if(channelCount() != other.channelCount())
638 int start = channelAxis == first
641 stop = channelAxis == last
644 int ostart = other.channelAxis == first
647 ostop = other.channelAxis == last
648 ? (int)other.size()-1
651 int len = stop - start;
652 if(len != ostop - ostart)
655 for(
int k=0; k<len; ++k)
656 if(shape[k+start] != other.shape[k+ostart])
// setChannelCount(count): adjust the channel entry of shape/original_shape.
// Visible fragments:
//   - remove the leading entry of both vectors;
//   - overwrite the last shape entry with count;
//   - pop the last original_shape entry;
//   - append count to both vectors.
// The conditions selecting among these (presumably channelAxis and the sign
// of count), the matching shape.pop_back(), and any channelAxis updates are
// on lines missing here; confirm.
661 TaggedShape & setChannelCount(
int count)
672 shape.erase(shape.begin());
673 original_shape.erase(original_shape.begin());
680 shape[size()-1] = count;
685 original_shape.pop_back();
692 shape.push_back(count);
693 original_shape.push_back(count);
// channelCount(): one visible branch returns the last shape entry; the other
// branches are on missing lines.
701 int channelCount()
const 708 return shape[size()-1];
// scaleAxisResolution(): for each axis whose extent changed from
// original_shape to shape, scale the corresponding axistags resolution by
// (original - 1) / (new - 1).
// If shape and original_shape differ in length, the branch at original
// line 718 is taken; its body is missing here (presumably an early return).
// `size` counts axes from sstart to the end, so a trailing channel axis is
// not excluded. finalizeTaggedShape() calls rotateToNormalOrder() before
// this function, which moves a trailing channel axis to the front.
716 void scaleAxisResolution(TaggedShape & tagged_shape)
718 if(tagged_shape.size() != tagged_shape.original_shape.size())
721 int ntags = tagged_shape.axistags.size();
723 ArrayVector<npy_intp> permute = tagged_shape.axistags.permutationToNormalOrder();
725 int tstart = (tagged_shape.axistags.channelIndex(ntags) < ntags)
728 int sstart = (tagged_shape.channelAxis == TaggedShape::first)
731 int size = (int)tagged_shape.size() - sstart;
733 for(
int k=0; k<size; ++k)
// `sk` is defined on a missing line (presumably k + sstart). Unchanged
// extents are skipped here.
736 if(tagged_shape.shape[sk] == tagged_shape.original_shape[sk])
// NOTE(review): when the new extent shape[sk] is 1 (and differs from the
// original), the denominator is 0.0, so factor becomes infinite. Confirm
// whether callers can produce extent-1 axes here.
738 double factor = (tagged_shape.original_shape[sk] - 1.0) / (tagged_shape.shape[sk] - 1.0);
739 tagged_shape.axistags.scaleResolution(permute[k+tstart], factor);
// unifyTaggedShapeSize(): reconcile the length of tagged_shape.shape with its
// axistags. Depending on whether the shape has a channel axis and whether
// the tags do (channelIndex == ntags means "no channel tag"), it drops or
// inserts the channel axis in the tags, or erases the leading shape entry.
// Each path then checks the lengths with the precondition
// "constructArray(): size mismatch between shape and axistags.".
// Several branch conditions and counter updates (e.g. after dropChannelAxis)
// are on lines missing here.
744 void unifyTaggedShapeSize(TaggedShape & tagged_shape)
// NOTE(review): this copy goes through PyAxisTags' copy constructor with
// createCopy = false. That presumably shares the Python object, which is
// what makes the drop/insert calls below visible through
// tagged_shape.axistags; confirm.
746 PyAxisTags axistags = tagged_shape.axistags;
// `shape` is a reference, so erasing from it modifies tagged_shape.shape.
747 ArrayVector<npy_intp> & shape = tagged_shape.shape;
749 int ndim = (int)shape.size();
750 int ntags = axistags.size();
752 long channelIndex = axistags.channelIndex();
754 if(tagged_shape.channelAxis == TaggedShape::none)
757 if(channelIndex == ntags)
761 vigra_precondition(ndim == ntags,
762 "constructArray(): size mismatch between shape and axistags.");
772 axistags.dropChannelAxis();
776 vigra_precondition(ndim == ntags,
777 "constructArray(): size mismatch between shape and axistags.");
784 if(channelIndex == ntags)
788 vigra_precondition(ndim == ntags+1,
789 "constructArray(): size mismatch between shape and axistags.");
795 shape.erase(shape.begin());
802 axistags.insertChannelAxis();
809 vigra_precondition(ndim == ntags,
810 "constructArray(): size mismatch between shape and axistags.");
// finalizeTaggedShape(): prepare a TaggedShape for array construction and
// return the resulting shape. When axistags are present it:
//   1. rotates a trailing channel axis to the front;
//   2. rescales resolutions of resized axes;
//   3. reconciles the shape length with the tags;
//   4. forwards a non-empty channelDescription to the axistags.
// Which of these calls sit inside the `if(tagged_shape.axistags)` block is
// decided by braces on lines missing here; confirm.
816 ArrayVector<npy_intp> finalizeTaggedShape(TaggedShape & tagged_shape)
818 if(tagged_shape.axistags)
820 tagged_shape.rotateToNormalOrder();
824 scaleAxisResolution(tagged_shape);
828 unifyTaggedShapeSize(tagged_shape);
830 if(tagged_shape.channelDescription !=
"")
831 tagged_shape.axistags.setChannelDescription(tagged_shape.channelDescription);
833 return tagged_shape.shape;
838 #endif // VIGRA_NUMPY_ARRAY_TAGGEDSHAPE_HXX std::ptrdiff_t MultiArrayIndex
Definition: multi_shape.hxx:55
Definition: accessor.hxx:43
FFTWComplex< R > & operator-=(FFTWComplex< R > &a, const FFTWComplex< R > &b)
subtract-assignment
Definition: fftw3.hxx:867
FFTWComplex< R > & operator+=(FFTWComplex< R > &a, const FFTWComplex< R > &b)
add-assignment
Definition: fftw3.hxx:859
FFTWComplex< R > & operator*=(FFTWComplex< R > &a, const FFTWComplex< R > &b)
multiply-assignment
Definition: fftw3.hxx:875
T sign(T t)
The sign function.
Definition: mathutil.hxx:553