diff --git a/ccdproc/image_collection.py b/ccdproc/image_collection.py index ff58c468..095ab99c 100644 --- a/ccdproc/image_collection.py +++ b/ccdproc/image_collection.py @@ -78,6 +78,10 @@ class ImageFileCollection: The extension from which the header and data will be read in all files.Default is ``0``. + extensions: list + A list of FITS extensions to search for. The default is ``['fit', + 'fits', 'fts']``. + Raises ------ ValueError @@ -87,7 +91,8 @@ class ImageFileCollection: def __init__(self, location=None, keywords=None, find_fits_by_reading=False, - filenames=None, glob_include=None, glob_exclude=None, ext=0): + filenames=None, glob_include=None, glob_exclude=None, ext=0, + extensions=None): # Include or exclude files from the collection based on glob pattern # matching - has to go above call to _get_files() if glob_exclude is not None: @@ -103,6 +108,17 @@ def __init__(self, location=None, keywords=None, else: self._location = '' + # Set file name extensions + # Do our best to keep the file extensions immutable + if extensions is not None: + if isinstance(extensions, str): + # Comma at the end to force it to be a tuple + self._file_extensions = (extensions,) + else: + self._file_extensions = tuple(extensions) + else: + self._file_extensions = tuple(_recognized_fits_file_extensions) + self._find_fits_by_reading = find_fits_by_reading self._filenames = filenames @@ -287,11 +303,19 @@ def glob_exclude(self): @property def ext(self): """ - str or int, The extension from which the header and data will + str or int, The FITS extension from which the header and data will be read in all files. """ return self._ext + @property + def file_extensions(self): + """ + List of file name extensions to match when populating or refreshing + the ``ImageFileCollection``. + """ + return self._file_extensions + def values(self, keyword, unique=False): """ List of values for a keyword. @@ -747,18 +771,13 @@ def _find_keywords_by_values(self, **kwd): self.summary['file'].mask = ma.nomask self.summary['file'].mask[~matches] = True - def _fits_files_in_directory(self, extensions=None, + def _fits_files_in_directory(self, compressed=True): """ Get names of FITS files in directory, based on filename extension. Parameters ---------- - extensions : list of str or None, optional - List of filename extensions that are FITS files. Default is - ``['fit', 'fits', 'fts']``. - Default is ``None``. - compressed : bool, optional If ``True``, compressed files should be included in the list (e.g. `.fits.gz`). @@ -769,10 +788,12 @@ def _fits_files_in_directory(self, extensions=None, list *Names* of the files (with extension), not the full pathname. """ + # Force a copy of the extensions to avoid endless combinations of + # compression extensions. + full_extensions = list(self.file_extensions) - full_extensions = extensions or list(_recognized_fits_file_extensions) - - # The common compressed fits image .fz is supported using ext=1 when calling ImageFileCollection + # The common compressed fits image .fz is supported using ext=1 when + # calling ImageFileCollection if compressed: for comp in ['.gz', '.bz2', '.Z', '.zip', '.fz']: with_comp = [extension + comp for extension in full_extensions] diff --git a/ccdproc/tests/test_image_collection.py b/ccdproc/tests/test_image_collection.py index 65cc4cc5..8111da34 100644 --- a/ccdproc/tests/test_image_collection.py +++ b/ccdproc/tests/test_image_collection.py @@ -1100,3 +1100,19 @@ def test_filtered_collection_with_no_files(self, triage_setup): ifc = ImageFileCollection(triage_setup.test_dir) ifc_no_files = ifc.filter(object='really fake object') + + @pytest.mark.parametrize('extensions', ('flubber', ['flubber'])) + def test_user_specified_file_extensions(self, tmp_path, extensions): + # Test for #727, allowing user to specify fits + # extensions + ccd = CCDData(data=np.zeros([10, 10]), unit='adu') + num_files = 4 + if len(extensions) == 1: + extension = extensions[0] + else: + extension = extensions + # Explicitly give the format for writing + _ = [ccd.write(tmp_path / f"ccd_{i}.{extension}", format='fits') + for i in range(num_files)] + ifc = ImageFileCollection(tmp_path, extensions=extensions) + assert len(ifc.summary) == num_files