[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

vigra/numpy_array_taggedshape.hxx VIGRA

00001 /************************************************************************/
00002 /*                                                                      */
00003 /*       Copyright 2009 by Ullrich Koethe and Hans Meine                */
00004 /*                                                                      */
00005 /*    This file is part of the VIGRA computer vision library.           */
00006 /*    The VIGRA Website is                                              */
00007 /*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
00008 /*    Please direct questions, bug reports, and contributions to        */
00009 /*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
00010 /*        vigra@informatik.uni-hamburg.de                               */
00011 /*                                                                      */
00012 /*    Permission is hereby granted, free of charge, to any person       */
00013 /*    obtaining a copy of this software and associated documentation    */
00014 /*    files (the "Software"), to deal in the Software without           */
00015 /*    restriction, including without limitation the rights to use,      */
00016 /*    copy, modify, merge, publish, distribute, sublicense, and/or      */
00017 /*    sell copies of the Software, and to permit persons to whom the    */
00018 /*    Software is furnished to do so, subject to the following          */
00019 /*    conditions:                                                       */
00020 /*                                                                      */
00021 /*    The above copyright notice and this permission notice shall be    */
00022 /*    included in all copies or substantial portions of the             */
00023 /*    Software.                                                         */
00024 /*                                                                      */
00025 /*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND    */
00026 /*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
00027 /*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
00028 /*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
00029 /*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
00030 /*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
00031 /*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
00032 /*    OTHER DEALINGS IN THE SOFTWARE.                                   */
00033 /*                                                                      */
00034 /************************************************************************/
00035 
00036 #ifndef VIGRA_NUMPY_ARRAY_TAGGEDSHAPE_HXX
00037 #define VIGRA_NUMPY_ARRAY_TAGGEDSHAPE_HXX
00038 
00039 #include <string>
00040 #include "array_vector.hxx"
00041 #include "python_utility.hxx"
00042 #include "axistags.hxx"
00043 
00044 namespace vigra {
00045 
00046 namespace detail {
00047 
00048 inline
00049 python_ptr getArrayTypeObject()
00050 {
00051     python_ptr arraytype((PyObject*)&PyArray_Type);
00052     python_ptr vigra(PyImport_ImportModule("vigra"));
00053     if(!vigra)
00054         PyErr_Clear();
00055     return pythonGetAttr(vigra, "standardArrayType", arraytype);
00056 }
00057 
00058 inline 
00059 std::string defaultOrder(std::string defaultValue = "C")
00060 {
00061     python_ptr arraytype = getArrayTypeObject();
00062     return pythonGetAttr(arraytype, "defaultOrder", defaultValue);
00063 }
00064 
00065 inline 
00066 python_ptr defaultAxistags(int ndim, std::string order = "")
00067 {
00068     if(order == "")
00069         order = defaultOrder();
00070     python_ptr arraytype = getArrayTypeObject();
00071     python_ptr func(PyString_FromString("defaultAxistags"), python_ptr::keep_count);
00072     python_ptr d(PyInt_FromLong(ndim), python_ptr::keep_count);
00073     python_ptr o(PyString_FromString(order.c_str()), python_ptr::keep_count);
00074     python_ptr axistags(PyObject_CallMethodObjArgs(arraytype, func.get(), d.get(), o.get(), NULL),
00075                         python_ptr::keep_count);
00076     if(axistags)
00077         return axistags;
00078     PyErr_Clear();
00079     return python_ptr();
00080 }
00081 
00082 inline 
00083 python_ptr emptyAxistags(int ndim)
00084 {
00085     python_ptr arraytype = getArrayTypeObject();
00086     python_ptr func(PyString_FromString("_empty_axistags"), python_ptr::keep_count);
00087     python_ptr d(PyInt_FromLong(ndim), python_ptr::keep_count);
00088     python_ptr axistags(PyObject_CallMethodObjArgs(arraytype, func.get(), d.get(), NULL),
00089                         python_ptr::keep_count);
00090     if(axistags)
00091         return axistags;
00092     PyErr_Clear();
00093     return python_ptr();
00094 }
00095 
00096 inline 
00097 void
00098 getAxisPermutationImpl(ArrayVector<npy_intp> & permute,
00099                        python_ptr object, const char * name, 
00100                        AxisInfo::AxisType type, bool ignoreErrors)
00101 {
00102     python_ptr func(PyString_FromString(name), python_ptr::keep_count);
00103     python_ptr t(PyInt_FromLong((long)type), python_ptr::keep_count);
00104     python_ptr permutation(PyObject_CallMethodObjArgs(object, func.get(), t.get(), NULL), 
00105                            python_ptr::keep_count);
00106     if(!permutation && ignoreErrors)
00107     {
00108         PyErr_Clear();
00109         return;
00110     }
00111     pythonToCppException(permutation);
00112     
00113     if(!PySequence_Check(permutation))
00114     {
00115         if(ignoreErrors)
00116             return;
00117         std::string message = std::string(name) + "() did not return a sequence.";
00118         PyErr_SetString(PyExc_ValueError, message.c_str());
00119         pythonToCppException(false);
00120     }
00121         
00122     ArrayVector<npy_intp> res(PySequence_Length(permutation));
00123     for(int k=0; k<(int)res.size(); ++k)
00124     {
00125         python_ptr i(PySequence_GetItem(permutation, k), python_ptr::keep_count);
00126         if(!PyInt_Check(i))
00127         {
00128             if(ignoreErrors)
00129                 return;
00130             std::string message = std::string(name) + "() did not return a sequence of int.";
00131             PyErr_SetString(PyExc_ValueError, message.c_str());
00132             pythonToCppException(false);
00133         }
00134         res[k] = PyInt_AsLong(i);
00135     }
00136     res.swap(permute);
00137 }
00138 
00139 inline 
00140 void
00141 getAxisPermutationImpl(ArrayVector<npy_intp> & permute,
00142                        python_ptr object, const char * name, bool ignoreErrors)
00143 {
00144     getAxisPermutationImpl(permute, object, name, AxisInfo::AllAxes, ignoreErrors);
00145 }
00146 
00147 } // namespace detail
00148 
00149 /********************************************************/
00150 /*                                                      */
00151 /*                     PyAxisTags                       */
00152 /*                                                      */
00153 /********************************************************/
00154 
00155 // FIXME: right now, we implement this class using the standard
00156 //        Python C-API only. It would be easier and more efficient 
00157 //        to use boost::python here, but it would cause NumpyArray
00158 //        to depend on boost, making it more difficult to use
00159 //        NumpyArray in connection with other glue code generators.
00160 class PyAxisTags
00161 {
00162   public:
00163     typedef PyObject * pointer;
00164     
00165     python_ptr axistags;
00166     
00167     PyAxisTags(python_ptr tags = python_ptr(), bool createCopy = false)
00168     {
00169         if(!tags)
00170             return;
00171         // FIXME: do a more elaborate type check here?
00172         if(!PySequence_Check(tags))
00173         {
00174             PyErr_SetString(PyExc_TypeError, 
00175                            "PyAxisTags(tags): tags argument must have type 'AxisTags'.");
00176             pythonToCppException(false);
00177         }
00178         else if(PySequence_Length(tags) == 0)
00179         {
00180             return;
00181         }
00182         
00183         if(createCopy)
00184         {
00185             python_ptr func(PyString_FromString("__copy__"), python_ptr::keep_count);
00186             axistags = python_ptr(PyObject_CallMethodObjArgs(tags, func.get(), NULL), 
00187                                   python_ptr::keep_count);
00188         }
00189         else
00190         {
00191             axistags = tags;
00192         }
00193     }
00194     
00195     PyAxisTags(PyAxisTags const & other, bool createCopy = false)
00196     {
00197         if(!other.axistags)
00198             return;
00199         if(createCopy)
00200         {
00201             python_ptr func(PyString_FromString("__copy__"), python_ptr::keep_count);
00202             axistags = python_ptr(PyObject_CallMethodObjArgs(other.axistags, func.get(), NULL), 
00203                                   python_ptr::keep_count);
00204         }
00205         else
00206         {
00207             axistags = other.axistags;
00208         }
00209     }
00210     
00211     PyAxisTags(int ndim, std::string const & order = "")
00212     {
00213         if(order != "")
00214             axistags = detail::defaultAxistags(ndim, order);
00215         else
00216             axistags = detail::emptyAxistags(ndim);
00217     }
00218     
00219     long size() const
00220     {
00221         return axistags
00222                    ? PySequence_Length(axistags)
00223                    : 0;
00224     }
00225     
00226     long channelIndex(long defaultVal) const
00227     {
00228         return pythonGetAttr(axistags, "channelIndex", defaultVal);
00229     }
00230 
00231     long channelIndex() const
00232     {
00233         return channelIndex(size());
00234     }
00235 
00236     bool hasChannelAxis() const
00237     {
00238         return channelIndex() != size();
00239     }
00240     
00241     long innerNonchannelIndex(long defaultVal) const
00242     {
00243         return pythonGetAttr(axistags, "innerNonchannelIndex", defaultVal);
00244     }
00245 
00246     long innerNonchannelIndex() const
00247     {
00248         return innerNonchannelIndex(size());
00249     }
00250 
00251     void setChannelDescription(std::string const & description)
00252     {
00253         if(!axistags)
00254             return;
00255         python_ptr d(PyString_FromString(description.c_str()), python_ptr::keep_count);
00256         python_ptr func(PyString_FromString("setChannelDescription"), python_ptr::keep_count);
00257         python_ptr res(PyObject_CallMethodObjArgs(axistags, func.get(), d.get(), NULL), 
00258                        python_ptr::keep_count);
00259         pythonToCppException(res);
00260     }
00261 
00262     double resolution(long index)
00263     {
00264         if(!axistags)
00265             return 0.0;
00266         python_ptr func(PyString_FromString("resolution"), python_ptr::keep_count);
00267         python_ptr i(PyInt_FromLong(index), python_ptr::keep_count);
00268         python_ptr res(PyObject_CallMethodObjArgs(axistags, func.get(), i.get(), NULL), 
00269                        python_ptr::keep_count);
00270         pythonToCppException(res);
00271         if(!PyFloat_Check(res))
00272         {
00273             PyErr_SetString(PyExc_TypeError, "AxisTags.resolution() did not return float.");
00274             pythonToCppException(false);
00275         }
00276         return PyFloat_AsDouble(res);
00277     }
00278  
00279     void setResolution(long index, double resolution)
00280     {
00281         if(!axistags)
00282             return;
00283         python_ptr func(PyString_FromString("setResolution"), python_ptr::keep_count);
00284         python_ptr i(PyInt_FromLong(index), python_ptr::keep_count);
00285         python_ptr r(PyFloat_FromDouble(resolution), python_ptr::keep_count);
00286         python_ptr res(PyObject_CallMethodObjArgs(axistags, func.get(), i.get(), r.get(), NULL), 
00287                        python_ptr::keep_count);
00288         pythonToCppException(res);
00289     }
00290  
00291     void scaleResolution(long index, double factor)
00292     {
00293         if(!axistags)
00294             return;
00295         python_ptr func(PyString_FromString("scaleResolution"), python_ptr::keep_count);
00296         python_ptr i(PyInt_FromLong(index), python_ptr::keep_count);
00297         python_ptr f(PyFloat_FromDouble(factor), python_ptr::keep_count);
00298         python_ptr res(PyObject_CallMethodObjArgs(axistags, func.get(), i.get(), f.get(), NULL), 
00299                        python_ptr::keep_count);
00300         pythonToCppException(res);
00301     }
00302  
00303     void toFrequencyDomain(long index, int size, int sign = 1)
00304     {
00305         if(!axistags)
00306             return;
00307         python_ptr func(sign == 1
00308                            ? PyString_FromString("toFrequencyDomain")
00309                            : PyString_FromString("fromFrequencyDomain"), 
00310                         python_ptr::keep_count);
00311         python_ptr i(PyInt_FromLong(index), python_ptr::keep_count);
00312         python_ptr s(PyInt_FromLong(size), python_ptr::keep_count);
00313         python_ptr res(PyObject_CallMethodObjArgs(axistags, func.get(), i.get(), s.get(), NULL), 
00314                        python_ptr::keep_count);
00315         pythonToCppException(res);
00316     }
00317  
00318     void fromFrequencyDomain(long index, int size)
00319     {
00320         toFrequencyDomain(index, size, -1);
00321     }
00322  
00323     ArrayVector<npy_intp> 
00324     permutationToNormalOrder(bool ignoreErrors = false) const
00325     {
00326         ArrayVector<npy_intp> permute;
00327         detail::getAxisPermutationImpl(permute, axistags, "permutationToNormalOrder", ignoreErrors);
00328         return permute;
00329     }
00330 
00331     ArrayVector<npy_intp> 
00332     permutationToNormalOrder(AxisInfo::AxisType types, bool ignoreErrors = false) const
00333     {
00334         ArrayVector<npy_intp> permute;
00335         detail::getAxisPermutationImpl(permute, axistags, 
00336                                             "permutationToNormalOrder", types, ignoreErrors);
00337         return permute;
00338     }
00339 
00340     ArrayVector<npy_intp> 
00341     permutationFromNormalOrder(bool ignoreErrors = false) const
00342     {
00343         ArrayVector<npy_intp> permute;
00344         detail::getAxisPermutationImpl(permute, axistags, 
00345                                        "permutationFromNormalOrder", ignoreErrors);
00346         return permute;
00347     }
00348     
00349     ArrayVector<npy_intp> 
00350     permutationFromNormalOrder(AxisInfo::AxisType types, bool ignoreErrors = false) const
00351     {
00352         ArrayVector<npy_intp> permute;
00353         detail::getAxisPermutationImpl(permute, axistags, 
00354                                        "permutationFromNormalOrder", types, ignoreErrors);
00355         return permute;
00356     }
00357     
00358     void dropChannelAxis()
00359     {
00360         if(!axistags)
00361             return;
00362         python_ptr func(PyString_FromString("dropChannelAxis"), 
00363                                python_ptr::keep_count);
00364         python_ptr res(PyObject_CallMethodObjArgs(axistags, func.get(), NULL), 
00365                        python_ptr::keep_count);
00366         pythonToCppException(res);
00367     }
00368     
00369     void insertChannelAxis()
00370     {
00371         if(!axistags)
00372             return;
00373         python_ptr func(PyString_FromString("insertChannelAxis"), 
00374                                python_ptr::keep_count);
00375         python_ptr res(PyObject_CallMethodObjArgs(axistags, func.get(), NULL), 
00376                        python_ptr::keep_count);
00377         pythonToCppException(res);
00378     }
00379     
00380     operator pointer()
00381     {
00382         return axistags.get();
00383     }
00384 
00385     bool operator!() const
00386     {
00387         return !axistags;
00388     }
00389 };
00390 
00391 /********************************************************/
00392 /*                                                      */
00393 /*                     TaggedShape                      */
00394 /*                                                      */
00395 /********************************************************/
00396 
00397 class TaggedShape
00398 {
00399   public:
00400     enum ChannelAxis { first, last, none };
00401     
00402     ArrayVector<npy_intp> shape, original_shape;
00403     PyAxisTags axistags;
00404     ChannelAxis channelAxis;
00405     std::string channelDescription;
00406     
00407     explicit TaggedShape(MultiArrayIndex size)
00408     : shape(size),
00409       axistags(size),
00410       channelAxis(none)
00411     {}
00412     
00413     template <class U, int N>
00414     TaggedShape(TinyVector<U, N> const & sh, PyAxisTags tags)
00415     : shape(sh.begin(), sh.end()),
00416       original_shape(sh.begin(), sh.end()),
00417       axistags(tags),
00418       channelAxis(none)
00419     {}
00420     
00421     template <class T>
00422     TaggedShape(ArrayVector<T> const & sh, PyAxisTags tags)
00423     : shape(sh.begin(), sh.end()),
00424       original_shape(sh.begin(), sh.end()),
00425       axistags(tags),
00426       channelAxis(none)
00427     {}
00428     
00429     template <class U, int N>
00430     explicit TaggedShape(TinyVector<U, N> const & sh)
00431     : shape(sh.begin(), sh.end()),
00432       original_shape(sh.begin(), sh.end()),
00433       channelAxis(none)
00434     {}
00435     
00436     template <class T>
00437     explicit TaggedShape(ArrayVector<T> const & sh)
00438     : shape(sh.begin(), sh.end()),
00439       original_shape(sh.begin(), sh.end()),
00440       channelAxis(none)
00441     {}
00442     
00443     template <class U, int N>
00444     TaggedShape & resize(TinyVector<U, N> const & sh)
00445     {
00446         int start = channelAxis == first
00447                         ? 1
00448                         : 0, 
00449             stop = channelAxis == last
00450                         ? (int)size()-1
00451                         : (int)size();
00452                         
00453         vigra_precondition(N == stop - start || size() == 0,
00454              "TaggedShape.resize(): size mismatch.");
00455              
00456         if(size() == 0)
00457             shape.resize(N);
00458         
00459         for(int k=0; k<N; ++k)
00460             shape[k+start] = sh[k];
00461             
00462         return *this;
00463     }
00464     
00465     TaggedShape & resize(MultiArrayIndex v1)
00466     {
00467         return resize(TinyVector<MultiArrayIndex, 1>(v1));
00468     }
00469     
00470     TaggedShape & resize(MultiArrayIndex v1, MultiArrayIndex v2)
00471     {
00472         return resize(TinyVector<MultiArrayIndex, 2>(v1, v2));
00473     }
00474     
00475     TaggedShape & resize(MultiArrayIndex v1, MultiArrayIndex v2, MultiArrayIndex v3)
00476     {
00477         return resize(TinyVector<MultiArrayIndex, 3>(v1, v2, v3));
00478     }
00479     
00480     TaggedShape & resize(MultiArrayIndex v1, MultiArrayIndex v2, 
00481                          MultiArrayIndex v3, MultiArrayIndex v4)
00482     {
00483         return resize(TinyVector<MultiArrayIndex, 4>(v1, v2, v3, v4));
00484     }
00485     
00486     npy_intp & operator[](int i)
00487     {
00488         return shape[i];
00489     }
00490     
00491     npy_intp operator[](int i) const
00492     {
00493         return shape[i];
00494     }
00495     
00496     unsigned int size() const
00497     {
00498         return shape.size();
00499     }
00500     
00501     TaggedShape & operator+=(int v)
00502     {
00503         int start = channelAxis == first
00504                         ? 1
00505                         : 0, 
00506             stop = channelAxis == last
00507                         ? (int)size()-1
00508                         : (int)size();
00509         for(int k=start; k<stop; ++k)
00510             shape[k] += v;
00511             
00512         return *this;
00513     }
00514     
00515     TaggedShape & operator-=(int v)
00516     {
00517         return operator+=(-v);
00518     }
00519     
00520     TaggedShape & operator*=(int factor)
00521     {
00522         int start = channelAxis == first
00523                         ? 1
00524                         : 0, 
00525             stop = channelAxis == last
00526                         ? (int)size()-1
00527                         : (int)size();
00528         for(int k=start; k<stop; ++k)
00529             shape[k] *= factor;
00530             
00531         return *this;
00532     }
00533     
00534     void rotateToNormalOrder()
00535     {
00536         if(axistags && channelAxis == last)
00537         {
00538             int ndim = (int)size();
00539             
00540             npy_intp channelCount = shape[ndim-1];            
00541             for(int k=ndim-1; k>0; --k)
00542                 shape[k] = shape[k-1];
00543             shape[0] = channelCount;
00544             
00545             channelCount = original_shape[ndim-1];            
00546             for(int k=ndim-1; k>0; --k)
00547                 original_shape[k] = original_shape[k-1];
00548             original_shape[0] = channelCount;
00549             
00550             channelAxis = first;
00551         }
00552     }
00553     
00554     TaggedShape & setChannelDescription(std::string const & description)
00555     {
00556         // we only remember the description here, and will actually set
00557         // it in the finalize function
00558         channelDescription = description;
00559         return *this;
00560     }
00561     
00562     TaggedShape & setChannelIndexLast()
00563     {
00564         // FIXME: add some checks?
00565         channelAxis = last;
00566         return *this;
00567     }
00568     
00569     // transposeShape() means: only shape and resolution are transposed, not the axis keys
00570     template <class U, int N>
00571     TaggedShape & transposeShape(TinyVector<U, N> const & p)
00572     {
00573         int ntags = axistags.size();
00574         ArrayVector<npy_intp> permute = axistags.permutationToNormalOrder();
00575         
00576         int tstart = (axistags.channelIndex(ntags) < ntags)
00577                         ? 1
00578                         : 0;
00579         int sstart = (channelAxis == first)
00580                         ? 1
00581                         : 0;
00582         int ndim = ntags - tstart;
00583 
00584         vigra_precondition(N == ndim,
00585              "TaggedShape.transposeShape(): size mismatch.");
00586              
00587         PyAxisTags newAxistags(axistags.axistags); // force copy
00588         for(int k=0; k<ndim; ++k)
00589         {
00590             original_shape[k+sstart] = shape[p[k]+sstart];
00591             newAxistags.setResolution(permute[k+tstart], axistags.resolution(permute[p[k]+tstart]));
00592         }
00593         shape = original_shape;
00594         axistags = newAxistags;
00595         
00596         return *this;
00597     }
00598 
00599     TaggedShape & toFrequencyDomain(int sign = 1)
00600     {
00601         int ntags = axistags.size();
00602         
00603         ArrayVector<npy_intp> permute = axistags.permutationToNormalOrder();
00604         
00605         int tstart = (axistags.channelIndex(ntags) < ntags)
00606                         ? 1
00607                         : 0;
00608         int sstart = (channelAxis == first)
00609                         ? 1
00610                         : 0;
00611         int send  = (channelAxis == last)
00612                         ? (int)size()-1
00613                         : (int)size();
00614         int size = send - sstart;
00615         
00616         for(int k=0; k<size; ++k)
00617         {
00618             axistags.toFrequencyDomain(permute[k+tstart], shape[k+sstart], sign);
00619         }
00620         
00621         return *this;
00622     }
00623 
00624     TaggedShape & fromFrequencyDomain()
00625     {
00626         return toFrequencyDomain(-1);
00627     }
00628     
00629     bool compatible(TaggedShape const & other) const
00630     {
00631         if(channelCount() != other.channelCount())
00632             return false;
00633             
00634         int start = channelAxis == first
00635                         ? 1
00636                         : 0, 
00637             stop = channelAxis == last
00638                         ? (int)size()-1
00639                         : (int)size();
00640         int ostart = other.channelAxis == first
00641                         ? 1
00642                         : 0, 
00643             ostop = other.channelAxis == last
00644                         ? (int)other.size()-1
00645                         : (int)other.size();
00646                         
00647         int len = stop - start;
00648         if(len != ostop - ostart)
00649             return false;
00650         
00651         for(int k=0; k<len; ++k)
00652             if(shape[k+start] != other.shape[k+ostart])
00653                 return false;
00654         return true;
00655     }
00656     
00657     TaggedShape & setChannelCount(int count)
00658     {
00659         switch(channelAxis)
00660         {
00661           case first:
00662             if(count > 0)
00663             {
00664                 shape[0] = count;
00665             }
00666             else
00667             {
00668                 shape.erase(shape.begin());
00669                 original_shape.erase(original_shape.begin());
00670                 channelAxis = none;
00671             }
00672             break;
00673           case last:
00674             if(count > 0)
00675             {
00676                 shape[size()-1] = count;
00677             }
00678             else
00679             {
00680                 shape.pop_back();
00681                 original_shape.pop_back();
00682                 channelAxis = none;
00683             }
00684             break;
00685           case none:
00686             if(count > 0)
00687             {
00688                 shape.push_back(count);
00689                 original_shape.push_back(count);
00690                 channelAxis = last;
00691             }
00692             break;
00693         }
00694         return *this;
00695     }
00696     
00697     int channelCount() const
00698     {
00699         switch(channelAxis)
00700         {
00701           case first:
00702             return shape[0];
00703           case last:
00704             return shape[size()-1];
00705           default:
00706             return 1;
00707         }
00708     }
00709 };
00710 
00711 inline 
00712 void scaleAxisResolution(TaggedShape & tagged_shape)
00713 {
00714     if(tagged_shape.size() != tagged_shape.original_shape.size())
00715         return;
00716     
00717     int ntags = tagged_shape.axistags.size();
00718     
00719     ArrayVector<npy_intp> permute = tagged_shape.axistags.permutationToNormalOrder();
00720     
00721     int tstart = (tagged_shape.axistags.channelIndex(ntags) < ntags)
00722                     ? 1
00723                     : 0;
00724     int sstart = (tagged_shape.channelAxis == TaggedShape::first)
00725                     ? 1
00726                     : 0;
00727     int size = (int)tagged_shape.size() - sstart;
00728     
00729     for(int k=0; k<size; ++k)
00730     {
00731         int sk = k + sstart;
00732         if(tagged_shape.shape[sk] == tagged_shape.original_shape[sk])
00733             continue;
00734         double factor = (tagged_shape.original_shape[sk] - 1.0) / (tagged_shape.shape[sk] - 1.0);
00735         tagged_shape.axistags.scaleResolution(permute[k+tstart], factor);
00736     }
00737 }
00738 
00739 inline 
00740 void unifyTaggedShapeSize(TaggedShape & tagged_shape)
00741 {
00742     PyAxisTags axistags = tagged_shape.axistags;
00743     ArrayVector<npy_intp> & shape = tagged_shape.shape;
00744 
00745     int ndim = (int)shape.size();
00746     int ntags = axistags.size();
00747     
00748     long channelIndex = axistags.channelIndex();
00749 
00750     if(tagged_shape.channelAxis == TaggedShape::none)
00751     {
00752         // shape has no channel axis
00753         if(channelIndex == ntags)
00754         {
00755             // std::cerr << "branch (shape, axitags) 0 0\n";
00756             // axistags have no channel axis either => sizes should match
00757             vigra_precondition(ndim == ntags,
00758                  "constructArray(): size mismatch between shape and axistags.");
00759         }
00760         else
00761         {
00762             // std::cerr << "branch (shape, axitags) 0 1\n";
00763             if(ndim+1 == ntags)
00764             {
00765                 // std::cerr << "   drop channel axis\n";
00766                 // axistags have one additional element => drop the channel tag
00767                 // FIXME: would it be cleaner to make this an error ?
00768                 axistags.dropChannelAxis();
00769             }
00770             else
00771             {
00772                 vigra_precondition(ndim == ntags,
00773                      "constructArray(): size mismatch between shape and axistags.");
00774             }
00775         }
00776     }
00777     else
00778     {
00779         // shape has a channel axis
00780         if(channelIndex == ntags)
00781         {
00782             // std::cerr << "branch (shape, axitags) 1 0\n";
00783             // axistags have no channel axis => should be one element shorter
00784             vigra_precondition(ndim == ntags+1,
00785                  "constructArray(): size mismatch between shape and axistags.");
00786                  
00787             if(shape[0] == 1)
00788             {
00789                 // std::cerr << "   drop channel axis\n";
00790                 // we have a singleband image => drop the channel axis
00791                 shape.erase(shape.begin());
00792                 ndim -= 1;
00793             }
00794             else
00795             {
00796                 // std::cerr << "   insert channel axis\n";
00797                 // we have a multiband image => add a channel tag
00798                 axistags.insertChannelAxis();
00799             }
00800         }
00801         else
00802         {
00803             // std::cerr << "branch (shape, axitags) 1 1\n";
00804             // axistags have channel axis => sizes should match
00805             vigra_precondition(ndim == ntags,
00806                  "constructArray(): size mismatch between shape and axistags.");
00807         }
00808     }
00809 }
00810 
00811 inline
00812 ArrayVector<npy_intp> finalizeTaggedShape(TaggedShape & tagged_shape)
00813 {
00814     if(tagged_shape.axistags)
00815     {
00816         tagged_shape.rotateToNormalOrder();
00817     
00818         // we assume here that the axistag object belongs to the array to be created
00819         // so that we can freely edit it
00820         scaleAxisResolution(tagged_shape);
00821             
00822         // this must be after scaleAxisResolution(), because the latter requires 
00823         // shape and original_shape to be still in sync
00824         unifyTaggedShapeSize(tagged_shape);
00825                 
00826         if(tagged_shape.channelDescription != "")
00827             tagged_shape.axistags.setChannelDescription(tagged_shape.channelDescription);
00828     }
00829     return tagged_shape.shape;
00830 }
00831 
00832 } // namespace vigra
00833 
00834 #endif // VIGRA_NUMPY_ARRAY_TAGGEDSHAPE_HXX

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.9.0 (Tue Nov 6 2012)